diff --git a/mlir/.clang-format b/mlir/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..392e20189554b6d594482cd341550576e98d32a8 --- /dev/null +++ b/mlir/.clang-format @@ -0,0 +1,2 @@ +BasedOnStyle: LLVM +AlwaysBreakTemplateDeclarations: Yes \ No newline at end of file diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..67d1f00322c5c3078f76e6ca752004a06dcb615f --- /dev/null +++ b/mlir/CMakeLists.txt @@ -0,0 +1,108 @@ +# MLIR project. +set(MLIR_MAIN_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include ) # --src-root +set(MLIR_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/include ) # --includedir +set(MLIR_TABLEGEN_EXE mlir-tblgen) + +set(MLIR_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(MLIR_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) + +function(mlir_tablegen ofn) + tablegen(MLIR ${ARGV} "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}") + set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} + PARENT_SCOPE) +endfunction() + +function(add_mlir_dialect dialect dialect_doc_filename) + set(LLVM_TARGET_DEFINITIONS ${dialect}.td) + mlir_tablegen(${dialect}.h.inc -gen-op-decls) + mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) + add_public_tablegen_target(MLIR${dialect}IncGen) + + # Generate Dialect Documentation + set(LLVM_TARGET_DEFINITIONS ${dialect_doc_filename}.td) + tablegen(MLIR ${dialect_doc_filename}.md -gen-op-doc "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}") + set(GEN_DOC_FILE ${MLIR_BINARY_DIR}/docs/Dialects/${dialect_doc_filename}.md) + add_custom_command( + OUTPUT ${GEN_DOC_FILE} + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_BINARY_DIR}/${dialect_doc_filename}.md + ${GEN_DOC_FILE} + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${dialect_doc_filename}.md) + add_custom_target(${dialect_doc_filename}DocGen DEPENDS ${GEN_DOC_FILE}) + add_dependencies(mlir-doc ${dialect_doc_filename}DocGen) +endfunction() + +add_custom_target(mlir-doc) + +# TODO: This is to 
handle the current static registration, but should be +# factored out a bit. +function(whole_archive_link target) + if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin") + set(link_flags "-L${CMAKE_BINARY_DIR}/lib ") + FOREACH(LIB ${ARGN}) + string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${CMAKE_BINARY_DIR}/lib/lib${LIB}.a ") + ENDFOREACH(LIB) + elseif(MSVC) + FOREACH(LIB ${ARGN}) + string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ") + ENDFOREACH(LIB) + else() + set(link_flags "-L${CMAKE_BINARY_DIR}/lib -Wl,--whole-archive,") + FOREACH(LIB ${ARGN}) + string(CONCAT link_flags ${link_flags} "-l${LIB},") + ENDFOREACH(LIB) + string(CONCAT link_flags ${link_flags} "--no-whole-archive") + endif() + set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags}) +endfunction(whole_archive_link) + +# Build the CUDA conversions and run according tests if the NVPTX backend +# is available +if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD) + set(MLIR_CUDA_CONVERSIONS_ENABLED 1) +else() + set(MLIR_CUDA_CONVERSIONS_ENABLED 0) +endif() + +set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner") + +include_directories( "include") +include_directories( ${MLIR_INCLUDE_DIR}) + +add_subdirectory(include/mlir) +add_subdirectory(lib) +add_subdirectory(tools) +add_subdirectory(unittests) +add_subdirectory(test) + +if( LLVM_INCLUDE_EXAMPLES ) + add_subdirectory(examples) +endif() + +if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY) + install(DIRECTORY include/mlir include/mlir-c + DESTINATION include + COMPONENT mlir-headers + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.inc" + PATTERN "LICENSE.TXT" + ) + + install(DIRECTORY ${MLIR_INCLUDE_DIR}/mlir ${MLIR_INCLUDE_DIR}/mlir-c + DESTINATION include + COMPONENT mlir-headers + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.gen" + PATTERN "*.inc" + PATTERN "CMakeFiles" EXCLUDE + PATTERN "config.h" EXCLUDE + ) + + if (NOT LLVM_ENABLE_IDE) + add_llvm_install_targets(install-mlir-headers + DEPENDS mlir-headers + COMPONENT 
See [https://mlir.llvm.org/](https://mlir.llvm.org/) for more information.
* Move constant operands of commutative binary operators to the right side -
  e.g. "(addi 4, x)" to "(addi x, 4)".
diff --git a/mlir/docs/ConversionToLLVMDialect.md b/mlir/docs/ConversionToLLVMDialect.md new file mode 100644 index 0000000000000000000000000000000000000000..19403e27dc4ce549d3dfc4f63312a78420659db2 --- /dev/null +++ b/mlir/docs/ConversionToLLVMDialect.md @@ -0,0 +1,443 @@ +# Conversion to the LLVM Dialect + +Conversion from the Standard to the [LLVM Dialect](Dialects/LLVM.md) can be +performed by the specialized dialect conversion pass by running + +```sh +mlir-opt -convert-std-to-llvm +``` + +It performs type and operation conversions for a subset of operations from +standard dialect (operations on scalars and vectors, control flow operations) as +described in this document. We use the terminology defined by the +[LLVM IR Dialect description](Dialects/LLVM.md) throughout this document. + +[TOC] + +## Type Conversion + +### Scalar Types + +Scalar types are converted to their LLVM counterparts if they exist. The +following conversions are currently implemented. + +- `i*` converts to `!llvm.i*` +- `f16` converts to `!llvm.half` +- `f32` converts to `!llvm.float` +- `f64` converts to `!llvm.double` + +Note: `bf16` type is not supported by LLVM IR and cannot be converted. + +### Index Type + +Index type is converted to a wrapped LLVM IR integer with bitwidth equal to the +bitwidth of the pointer size as specified by the +[data layout](https://llvm.org/docs/LangRef.html#data-layout) of the LLVM module +[contained](Dialects/LLVM.md#context-and-module-association) in the LLVM Dialect +object. For example, on x86-64 CPUs it converts to `!llvm.i64`. + +### Vector Types + +LLVM IR only supports *one-dimensional* vectors, unlike MLIR where vectors can +be multi-dimensional. Vector types cannot be nested in either IR. In the +one-dimensional case, MLIR vectors are converted to LLVM IR vectors of the same +size with element type converted using these conversion rules. 
In the +n-dimensional case, MLIR vectors are converted to (n-1)-dimensional array types +of one-dimensional vectors. + +For example, `vector<4 x f32>` converts to `!llvm<"<4 x float>">` and `vector<4 +x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. + +### Memref Types + +Memref types in MLIR have both static and dynamic information associated with +them. The dynamic information comprises the buffer pointer as well as sizes and +strides of any dynamically sized dimensions. Memref types are normalized and +converted to a descriptor that is only dependent on the rank of the memref. The +descriptor contains: + +1. the pointer to the data buffer, followed by +2. the pointer to properly aligned data payload that the memref indexes, + followed by +3. a lowered `index`-type integer containing the distance between the beginning + of the buffer and the first element to be accessed through the memref, + followed by +4. an array containing as many `index`-type integers as the rank of the memref: + the array represents the size, in number of elements, of the memref along + the given dimension. For constant MemRef dimensions, the corresponding size + entry is a constant whose runtime value must match the static value, + followed by +5. a second array containing as many 64-bit integers as the rank of the MemRef: + the second array represents the "stride" (in tensor abstraction sense), i.e. + the number of consecutive elements of the underlying buffer. + +For constant memref dimensions, the corresponding size entry is a constant whose +runtime value matches the static value. This normalization serves as an ABI for +the memref type to interoperate with externally linked functions. In the +particular case of rank `0` memrefs, the size and stride arrays are omitted, +resulting in a struct containing two pointers + offset. 
memref<f32> -> !llvm<"{ float*, float*, i64 }">
memref<1 x f32> -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
memref<? x f32> -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// binary function with two results
(i32, f32) -> (i64, f64)
// has its result aggregated into a structure type
!llvm<"{i64, double} (i32, float)">
  // call and extract the values from the structure
  %2 = llvm.call @foo(%0, %1) : (!llvm.i32, !llvm.i64) -> !llvm<"{i32, i64}">
that the storage remains live until the callee returns. The caller can then pass
the pointer to that memory as a function argument. The callee loads from the
pointers it was passed as arguments in the entry block of the function, making
the descriptor passed in as argument available for use similarly to
locally-defined descriptors.
+ %15 = llvm.mlir.constant(1 : index) : !llvm.i64 + %16 = llvm.alloca %15 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*"> + llvm.store %14, %16 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*"> + // Pass the pointer to the function. + llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> () + llvm.return +} +``` + +*This convention may or may not apply if the conversion of MemRef types is +overridden by the user.* + +## Repeated Successor Removal + +Since the goal of the LLVM IR dialect is to reflect LLVM IR in MLIR, the dialect +and the conversion procedure must account for the differences between block +arguments and LLVM IR PHI nodes. In particular, LLVM IR disallows PHI nodes with +different values coming from the same source. Therefore, the LLVM IR dialect +disallows operations that have identical successors accepting arguments, which +would lead to invalid PHI nodes. The conversion process resolves the potential +PHI source ambiguity by injecting dummy blocks if the same block is used more +than once as a successor in an instruction. These dummy blocks branch +unconditionally to the original successors, pass them the original operands +(available in the dummy block because it is dominated by the original block) and +are used instead of them in the original terminator operation. 
  cond_br %0, ^bb1(%1 : i32), ^dummy
^bb1(%3 : i32):
  "use"(%3) : (i32) -> ()
^dummy:
  br ^bb1(%2 : i32)
// Compute the linearized index from strides. Each block below extracts one
// stride from the descriptor, multiplies it with the index and accumulates
// the total offset.
+%0 = llvm.load %ptr : !llvm<"float*"> +``` + +For stores, the address computation code is identical and only the actual store +operation is different. + +Note: the conversion does not perform any sort of common subexpression +elimination when emitting memref accesses. diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md new file mode 100644 index 0000000000000000000000000000000000000000..67ff102fef968bc0eb8bedbb2082c0e62f4808ca --- /dev/null +++ b/mlir/docs/DeclarativeRewrites.md @@ -0,0 +1,690 @@ +# Table-driven Declarative Rewrite Rule (DRR) + +In addition to subclassing the `mlir::RewritePattern` C++ class, MLIR also +supports defining rewrite rules in a declarative manner. Similar to +[Op Definition Specification](OpDefinitions.md) (ODS), this is achieved via +[TableGen][TableGen], which is a language to maintain records of domain-specific +information. The rewrite rules are specified concisely in a TableGen record, +which will be expanded into an equivalent `mlir::RewritePattern` subclass at +compiler build time. + +This manual explains in detail all of the available mechanisms for defining +rewrite rules in such a declarative manner. It aims to be a specification +instead of a tutorial. Please refer to +[Quickstart tutorial to adding MLIR graph rewrite](QuickstartRewrites.md) for +the latter. + +Given that declarative rewrite rules depend on op definition specification, this +manual assumes knowledge of the [ODS](OpDefinitions.md) doc. + +## Benefits + +Compared to the hand-written C++ classes, this declarative approach has several +benefits, including but not limited to: + +* **Being declarative**: The pattern creator just needs to state the rewrite + pattern declaratively, without worrying about the concrete C++ methods to + call. +* **Removing boilerplate and showing the very essence of the rewrite**: + `mlir::RewritePattern` is already good at hiding boilerplate for defining a + rewrite rule. 
class Pattern<
    dag sourcePattern, list<dag> resultPatterns,
    list<dag> additionalConstraints = [],
    dag benefitsAdded = (addBenefit 0)>;
class Pat<
    dag sourcePattern, dag resultPattern,
    list<dag> additionalConstraints = [],
    dag benefitsAdded = (addBenefit 0)> :
  Pattern<sourcePattern, [resultPattern], additionalConstraints, benefitsAdded>;
+ +#### Binding op arguments and limiting the match + +For example, + +```tblgen +def AOp : Op<"a_op"> { + let arguments = (ins + AnyType:$a_input, + AnyAttr:$a_attr + ); + + let results = (outs + AnyType:$a_output + ); +} + +def : Pat<(AOp $input, F32Attr:$attr), ...>; +``` + +In the above, we are matching an `AOp` whose `$input` can be anything valid as +defined by the op and whose `$attr` must be a float attribute. If the match +succeeds, we bind the `$input` symbol to the op's only input (`$a_input`) and +`$attr` to the only attribute (`$a_attr`); we can reference them using `$input` +and `$attr` in result patterns and additional constraints. + +The pattern is position-based: the symbol names used for capturing here do not +need to match with the op definition as shown in the above example. As another +example, the pattern can be written as ` def : Pat<(AOp $a, F32Attr:$b), ...>;` +and use `$a` and `$b` to refer to the captured input and attribute. But using +the ODS name directly in the pattern is also allowed. + +Also note that we only need to add `TypeConstraint` or `AttributeConstraint` +when we need to further limit the match criteria. If all valid cases to the op +are acceptable, then we can leave the constraint unspecified. + +`$_` is a special symbol to mean ignore capturing an argument. For example, +`def : Pat<(AOp $_, $b), ...>` means only `$b` is interesting to capture and +will be referenced later in result patterns. It's still possible to place +additional constraints even if the symbol is not to be captured; for such case, +you can simply use just the `TypeConstraint` or `AttributeConstraint` without a +bound symbol, for example, `def : Pat<(AOp $a, F32Attr), ...>`. 
To match a DAG of ops, use nested `dag` objects:
One of them has aggregated
+parameters for result types, operands, and attributes in the signature: `void
+COp::build(..., ArrayRef<Type> resultTypes, ArrayRef<Value> operands,
+ArrayRef<NamedAttribute> attr)`. The pattern in the above calls this `build()`
+method for constructing the `COp`.
+
+In general, arguments in the result pattern will be passed directly to the
+`build()` method. To leverage the auto-generated `build()` method, list them in
+the pattern by following the exact same order as the ODS `arguments` definition.
+Otherwise, a custom `build()` method that matches the argument list is required.
+
+Right now all ODS-generated `build()` methods require specifying the result
+type(s), unless the op has known traits like `SameOperandsAndResultType` that
+we can use to auto-generate a `build()` method with result type deduction.
+When generating an op to replace the result of the matched root op, we can use
+the matched root op's result type when calling the ODS-generated builder.
+Otherwise (e.g., generating an [auxiliary op](#supporting-auxiliary-ops) or
+generating an op with a nested result pattern), DRR will not be able to deduce
+the result type(s). The pattern author will need to define a custom builder
+that has result type deduction ability via `OpBuilder` in ODS. For example,
+in the following pattern
+
+```tblgen
+def : Pat<(AOp $input, $attr), (COp (AOp $input, $attr) $attr)>;
+```
+
+`AOp` is generated via a nested result pattern; DRR won't be able to deduce the
+result type for it. A custom builder for `AOp` should be defined and it should
+deduce the result type by itself. The builder should have a separate parameter
+for each operand and attribute and deduce the result type internally by itself.
+For example, for the above `AOp`, a possible builder is: + +```c++ + +void AOp::build(Builder *builder, OperationState &state, + Value input, Attribute attr) { + state.addOperands({input}); + state.addAttribute("a_attr", attr); + Type type = ...; // Deduce result type here + state.addTypes({type}); +} +``` + +Failing to define such a builder will result in an error at C++ compilation time +saying the call to `AOp::build()` cannot be resolved because of the number of +parameters mismatch. + +#### Generating DAG of operations + +`dag` objects can be nested to generate a DAG of operations: + +```tblgen +def : Pat<(AOp $input, $attr), (COp (BOp), $attr)>; +``` + +In the above, we generate a `BOp`, and then use its result to generate the `COp` +to replace the matched `AOp`. + +#### Binding op results + +In the result pattern, we can bind to the result(s) of a newly built op by +attaching symbols to the op. (But we **cannot** bind to op arguments given that +they are referencing previously bound symbols.) This is useful for reusing +newly created results where suitable. For example, + +```tblgen +def DOp : Op<"d_op"> { + let arguments = (ins + AnyType:$d_input1, + AnyType:$d_input2, + ); + + let results = (outs + AnyType:$d_output + ); +} + +def : Pat<(AOp $input, $ignored_attr), (DOp (BOp:$b_result) $b_result)>; +``` + +In this pattern, an `AOp` is matched and replaced with a `DOp` whose two +operands are from the result of a single `BOp`. This is only possible by binding +the result of the `BOp` to a name and reuse it for the second operand of the +`DOp` + +#### `NativeCodeCall`: transforming the generated op + +Sometimes the captured arguments are not exactly what we want so they cannot be +directly fed in as arguments to build the new op. For such cases, we can apply +transformations on the arguments by calling into C++ helper functions. This is +achieved by `NativeCodeCall`. 
+ +For example, if we want to capture some op's attributes and group them as an +array attribute to construct a new op: + +```tblgen + +def TwoAttrOp : Op<"two_attr_op"> { + let arguments = (ins + AnyAttr:$op_attr1, + AnyAttr:$op_attr2 + ); + + let results = (outs + AnyType:$op_output + ); +} + +def OneAttrOp : Op<"one_attr_op"> { + let arguments = (ins + ArrayAttr:$op_attr + ); + + let results = (outs + AnyType:$op_output + ); +} +``` + +We can write a C++ helper function: + +```c++ +Attribute createArrayAttr(Builder &builder, Attribute a, Attribute b) { + return builder.getArrayAttr({a, b}); +} +``` + +And then write the pattern as: + +```tblgen +def createArrayAttr : NativeCodeCall<"createArrayAttr($_builder, $0, $1)">; + +def : Pat<(TwoAttrOp $attr1, $attr2), + (OneAttrOp (createArrayAttr $attr1, $attr2))>; +``` + +And make sure the generated C++ code from the above pattern has access to the +definition of the C++ helper function. + +In the above example, we are using a string to specialize the `NativeCodeCall` +template. The string can be an arbitrary C++ expression that evaluates into +some C++ object expected at the `NativeCodeCall` site (here it would be +expecting an array attribute). Typically the string should be a function call. + +Note that currently `NativeCodeCall` must return no more than one value or +attribute. This might change in the future. + +##### `NativeCodeCall` placeholders + +In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N`. The former +is called _special placeholder_, while the latter is called _positional +placeholder_. + +`NativeCodeCall` right now only supports two special placeholders: `$_builder` +and `$_self`: + +* `$_builder` will be replaced by the current `mlir::PatternRewriter`. +* `$_self` will be replaced with the entity `NativeCodeCall` is attached to. 
+
+We have seen how `$_builder` can be used in the above; it allows us to pass a
+`mlir::Builder` (`mlir::PatternRewriter` is a subclass of `mlir::OpBuilder`,
+which is a subclass of `mlir::Builder`) to the C++ helper function to use the
+handy methods on `mlir::Builder`.
+
+`$_self` is useful when we want to write something in the form of
+`NativeCodeCall<"...">:$symbol`. For example, if we want to reverse the previous
+example and decompose the array attribute into two attributes:
+
+```tblgen
+class getNthAttr<int n> : NativeCodeCall<"$_self.getValue()[" # n # "]">;
+
+def : Pat<(OneAttrOp $attr),
+          (TwoAttrOp (getNthAttr<0>:$attr), (getNthAttr<1>:$attr))>;
+```
+
+In the above, `$_self` is substituted by the attribute bound by `$attr`, which
+is `OneAttrOp`'s array attribute.
+
+Positional placeholders will be substituted by the `dag` object parameters at
+the `NativeCodeCall` use site. For example, if we define `SomeCall :
+NativeCodeCall<"someFn($1, $2, $0)">` and use it like `(SomeCall $in0, $in1,
+$in2)`, then this will be translated into C++ call `someFn($in1, $in2, $in0)`.
+
+##### Customizing entire op building
+
+`NativeCodeCall` is not only limited to transforming arguments for building an
+op; it can be also used to specify how to build an op entirely. An example:
+
+If we have a C++ function for building an op:
+
+```c++
+Operation *createMyOp(OpBuilder builder, Value input, Attribute attr);
+```
+
+We can wrap it up and invoke it like:
+
+```tblgen
+def createMyOp : NativeCodeCall<"createMyOp($_builder, $0, $1)">;
+
+def : Pat<(... $input, $attr), (createMyOp $input, $attr)>;
+```
+
+### Supporting auxiliary ops
+
+A declarative rewrite rule supports multiple result patterns. One of the
+purposes is to allow generating _auxiliary ops_. Auxiliary ops are operations
+used for building the replacement ops; but they are not directly used for
+replacement themselves.
+
+For the case of uni-result ops, if there are multiple result patterns, only the
+value generated from the last result pattern will be used to replace the matched
+root op's result; all other result patterns will be considered as generating
+auxiliary ops.
+
+Normally we want to specify ops as nested `dag` objects if their def-use
+relationship can be expressed in the way that an op's result can feed as the
+argument to consuming op. But that is not always possible. For example, if we
+want to allocate memory and store some computation (in pseudocode):
+
+```mlir
+%dst = addi %lhs, %rhs
+```
+
+into
+
+```mlir
+%shape = shape %lhs
+%mem = alloc %shape
+%sum = addi %lhs, %rhs
+store %mem, %sum
+%dst = load %mem
+```
+
+We cannot fit in with just one result pattern given `store` does not return a
+value. Instead we can use multiple result patterns:
+
+```tblgen
+def : Pattern<(AddIOp $lhs, $rhs),
+              [(StoreOp (AllocOp:$mem (ShapeOp $lhs)), (AddIOp $lhs, $rhs)),
+               (LoadOp $mem)]>;
+```
+
+In the above we use the first result pattern to generate the first four ops, and
+use the last pattern to generate the last op, which is used to replace the
+matched op.
+
+### Supporting multi-result ops
+
+Multi-result ops bring extra complexity to declarative rewrite rules. We use
+TableGen `dag` objects to represent ops in patterns; there is no native way to
+indicate that an op generates multiple results. The approach adopted is based
+on **naming convention**: a `__N` suffix is added to a symbol to indicate the
+`N`-th result.
+
+#### `__N` suffix
+
+The `__N` suffix is specifying the `N`-th result as a whole (which can be
+[variadic](#supporting-variadic-ops)).
For example, we can bind a symbol to some +multi-result op and reference a specific result later: + +```tblgen +def ThreeResultOp : Op<"three_result_op"> { + let arguments = (ins ...); + + let results = (outs + AnyTensor:$op_output1, + AnyTensor:$op_output2, + AnyTensor:$op_output3 + ); +} + +def : Pattern<(ThreeResultOp:$results ...), + [(... $results__0), ..., (... $results__2), ...]>; +``` + +In the above pattern we bind `$results` to all the results generated by +`ThreeResultOp` and references its `$input1` and `$input3` later in the result +patterns. + +We can also bind a symbol and reference one of its specific result at the same +time, which is typically useful when generating multi-result ops: + +```tblgen +// TwoResultOp has similar definition as ThreeResultOp, but only has two +// results. + +def : Pattern<(TwoResultOp ...), + [(ThreeResultOp:$results__2, ...), + (replaceWithValue $results__0)]>; +``` + +In the above, we created a `ThreeResultOp` and bind `results` to its results, +and uses its last result (`$output3`) and first result (`$output1`) to replace +the `TwoResultOp`'s two results, respectively. + +#### Replacing multi-result ops + +The above example also shows how to replace a matched multi-result op. + +To replace a `N`-result op, the result patterns must generate at least `N` +declared values (see [Declared vs. actual value](#declared-vs-actual-value) for +definition). If there are more than `N` declared values generated, only the +last `N` declared values will be used to replace the matched op. Note that +because of the existence of multi-result op, one result pattern **may** generate +multiple declared values. So it means we do not necessarily need `N` result +patterns to replace an `N`-result op. For example, to replace an op with three +results, you can have + +```tblgen +// ThreeResultOp/TwoResultOp/OneResultOp generates three/two/one result(s), +// respectively. + +// Replace each result with a result generated from an individual op. 
+
+def : Pattern<(ThreeResultOp ...),
+              [(OneResultOp ...), (OneResultOp ...), (OneResultOp ...)]>;
+
+// Replace the first two results with two results generated from the same op.
+def : Pattern<(ThreeResultOp ...),
+              [(TwoResultOp ...), (OneResultOp ...)]>;
+
+// Replace all three results with three results generated from the same op.
+def : Pat<(ThreeResultOp ...), (ThreeResultOp ...)>;
+
+// Generate an auxiliary op first; the last op generates the three results
+// used for replacement.
+def : Pattern<(ThreeResultOp ...),
+              [(AuxiliaryOp ...), (ThreeResultOp ...)]>;
+```
+
+But using a single op to serve as both auxiliary op and replacement op is
+forbidden, i.e., the following is not allowed because the first
+`TwoResultOp` generates two results but only the second result is used for
+replacing the matched op's result:
+
+```tblgen
+def : Pattern<(ThreeResultOp ...),
+              [(TwoResultOp ...), (TwoResultOp ...)]>;
+```
+
+### Supporting variadic ops
+
+#### Declared vs. actual value
+
+Before going into details on variadic op support, we need to define a few terms
+regarding an op's values.
+
+* _Value_: either an operand or a result
+* _Declared operand/result/value_: an operand/result/value statically declared
+  in ODS of the op
+* _Actual operand/result/value_: an operand/result/value of an op instance at
+  runtime
+
+The above terms are needed because ops can have multiple results, and some of the
+results can also be variadic. For example,
+
+```tblgen
+def MultiVariadicOp : Op<"multi_variadic_op"> {
+  let arguments = (ins
+    AnyTensor:$input1,
+    Variadic<AnyTensor>:$input2,
+    AnyTensor:$input3
+  );
+
+  let results = (outs
+    AnyTensor:$output1,
+    Variadic<AnyTensor>:$output2,
+    AnyTensor:$output3
+  );
+}
+```
+
+We say the above op has 3 declared operands and 3 declared results. But at
+runtime, an instance can have 3 values corresponding to `$input2` and 2 values
+correspond to `$output2`; we say it has 5 actual operands and 4 actual
+results. A variadic operand/result is considered as a declared value that can
+correspond to multiple actual values.
+
+[TODO]
+
+### Supplying additional constraints
+
+Constraints can be placed on op arguments when matching. But sometimes we need
+to also place constraints on the matched op's results or sometimes need to limit
+the matching with some constraints that cover both the arguments and the
+results. The third parameter to `Pattern` (and `Pat`) is for this purpose.
+
+For example, we can write
+
+```tblgen
+def HasNoUseOf: Constraint<
+    CPred<"$_self->use_begin() == $_self->use_end()">, "has no use">;
+
+def HasSameElementType : Constraint<
+    CPred<"$0.cast<ShapedType>().getElementType() == "
+          "$1.cast<ShapedType>().getElementType()">,
+    "has same element type">;
+
+def : Pattern<(TwoResultOp:$results $input),
+              [(...), (...)],
+              [(F32Tensor:$results__0), (HasNoUseOf:$results__1),
+               (HasSameElementType $results__0, $input)]>;
+```
+
+You can
+
+* Use normal `TypeConstraint`s on previous bound symbols (the first result of
+  `TwoResultOp` must be a float tensor);
+* Define new `Constraint` for previous bound symbols (the second result of
+  `TwoResultOp` must have no use);
+* Apply constraints on multiple bound symbols (`$input` and `TwoResultOp`'s
+  first result must have the same element type).
+
+### Adjusting benefits
+
+The benefit of a `Pattern` is an integer value indicating the benefit of matching
+the pattern. It determines the priorities of patterns inside the pattern rewrite
+driver. A pattern with a higher benefit is applied before one with a lower
+benefit.
+
+In DRR, a rule is set to have a benefit of the number of ops in the source
+pattern. This is based on the heuristics and assumptions that:
+
+* Larger matches are more beneficial than smaller ones.
+* If a smaller one is applied first the larger one may not apply anymore.
+
+
+The fourth parameter to `Pattern` (and `Pat`) allows manually tweaking a
+pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value.
+ +## Special directives + +[TODO] + +## Debugging Tips + +### Run `mlir-tblgen` to see the generated content + +TableGen syntax sometimes can be obscure; reading the generated content can be +a very helpful way to understand and debug issues. To build `mlir-tblgen`, run +`cmake --build . --target mlir-tblgen` in your build directory and find the +`mlir-tblgen` binary in the `bin/` subdirectory. All the supported generators +can be found via `mlir-tblgen --help`. + +To see the generated code, invoke `mlir-tblgen` with a specific generator by +providing include paths via `-I`. For example, + +```sh +# To see all the C++ pattern rewrite classes +mlir-tblgen --gen-rewriters -I /path/to/mlir/include /path/to/input/td/file +``` + +### Compilation error: no matching member function for call to 'build' + +This is because DRR is failing to call a `build()` method with result type +deduction ability. See [building operations](#building-operations) for more +details. + +[TableGen]: https://llvm.org/docs/TableGen/index.html +[OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td diff --git a/mlir/docs/DefiningAttributesAndTypes.md b/mlir/docs/DefiningAttributesAndTypes.md new file mode 100644 index 0000000000000000000000000000000000000000..60243e5fd57fc8937ed5f6db51fa243dc82eb06a --- /dev/null +++ b/mlir/docs/DefiningAttributesAndTypes.md @@ -0,0 +1,282 @@ +# Quickstart tutorial to defining custom dialect attributes and types + +This document is a quickstart to defining dialect specific extensions to the +[attribute](LangRef.md#attributes) and [type system](LangRef.md#type-system). +The main part of the tutorial focuses on defining types, but the instructions +are nearly identical for defining attributes. + +See [MLIR specification](LangRef.md) for more information about MLIR, the +structure of the IR, operations, etc. + +## Types + +Types in MLIR (like attributes, locations, and many other things) are +value-typed. 
This means that instances of `Type` should be passed around
+by-value, as opposed to by-pointer or by-reference. The `Type` class in itself
+acts as a wrapper around an internal storage object that is uniqued within an
+instance of an `MLIRContext`.
+
+### Reserving a range of type kinds
+
+Types in MLIR rely on having a unique `kind` value to ensure that casting checks
+remain extremely efficient
+([rationale](Rationale.md#reserving-dialect-type-kinds)). For a dialect
+author, this means that a range of type `kind` values must be explicitly, and
+statically, reserved. A dialect can reserve a range of values by adding a new
+entry to the
+[DialectSymbolRegistry](https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/DialectSymbolRegistry.def).
+To support out-of-tree and experimental dialects, the registry predefines a set
+of private ranges, `PRIVATE_EXPERIMENTAL_[0-9]`, that are free for immediate
+use.
+
+```c++
+DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
+DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
+
+// The following ranges are reserved for experimenting with MLIR dialects in a
+// private context without having to register them here.
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
+```
+
+For the sake of this tutorial, we will use the predefined
+`PRIVATE_EXPERIMENTAL_0` range. These definitions will provide a range in the
+`Type::Kind` enum to use when defining the derived types.
+
+```c++
+namespace MyTypes {
+enum Kinds {
+  // These kinds will be used in the examples below.
+  Simple = Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
+  Complex
+};
+}
+```
+
+### Defining the type class
+
+As described above, `Type` objects in MLIR are value-typed and rely on having an
+implicit internal storage object that holds the actual data for the type. When
+defining a new `Type` it isn't always necessary to define a new storage class.
+
+So before defining the derived `Type`, it's important to know which of the two
+classes of `Type` we are defining. Some types are `primitives`, meaning they do
+not have any parameters and are singletons uniqued by kind, like the
+[`index` type](LangRef.md#index-type). Parametric types, on the other hand, have
+additional information that differentiates different instances of the same
+`Type` kind. For example the [`integer` type](LangRef.md#integer-type) has a
+bitwidth, making `i8` and `i16` different instances of
+[`integer` type](LangRef.md#integer-type).
+
+#### Simple non-parametric types
+
+For simple parameterless types, we can jump straight into defining the derived
+type class. Given that these types are uniqued solely on `kind`, we don't need
+to provide our own storage class.
+
+```c++
+/// This class defines a simple parameterless type. All derived types must
+/// inherit from the CRTP class 'Type::TypeBase'. It takes as template
+/// parameters the concrete type (SimpleType), and the base class to use (Type).
+/// 'Type::TypeBase' also provides several utility methods to simplify type
+/// construction.
+class SimpleType : public Type::TypeBase<SimpleType, Type> {
+public:
+  /// Inherit some necessary constructors from 'TypeBase'.
+  using Base::Base;
+
+  /// This static method is used to support type inquiry through isa, cast,
+  /// and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == MyTypes::Simple; }
+
+  /// This method is used to get an instance of the 'SimpleType'. Given that
+  /// this is a parameterless type, it just needs to take the context for
+  /// uniquing purposes.
+  static SimpleType get(MLIRContext *context) {
+    // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
+    // of this type.
+    return Base::get(context, MyTypes::Simple);
+  }
+};
+```
+
+#### Parametric types
+
+Parametric types are those that have additional construction or uniquing
+constraints outside of the type `kind`.
As such, these types require defining a
+type storage class.
+
+##### Defining a type storage
+
+Type storage objects contain all of the data necessary to construct and unique a
+parametric type instance. The storage classes must obey the following:
+
+* Inherit from the base type storage class `TypeStorage`.
+* Define a type alias, `KeyTy`, that maps to a type that uniquely identifies
+  an instance of the parent type.
+* Provide a construction method that is used to allocate a new instance of the
+  storage class.
+  - `Storage *construct(TypeStorageAllocator &, const KeyTy &key)`
+* Provide a comparison method between the storage and `KeyTy`.
+  - `bool operator==(const KeyTy &) const`
+* Provide a method to generate the `KeyTy` from a list of arguments passed to
+  the uniquer. (Note: This is only necessary if the `KeyTy` cannot be default
+  constructed from these arguments).
+  - `static KeyTy getKey(Args...&& args)`
+* Provide a method to hash an instance of the `KeyTy`. (Note: This is not
+  necessary if an `llvm::DenseMapInfo` specialization exists)
+  - `static llvm::hash_code hashKey(const KeyTy &)`
+
+Let's look at an example:
+
+```c++
+/// Here we define a storage class for a ComplexType, that holds a non-zero
+/// integer and an integer type.
+struct ComplexTypeStorage : public TypeStorage {
+  ComplexTypeStorage(unsigned nonZeroParam, Type integerType)
+      : nonZeroParam(nonZeroParam), integerType(integerType) {}
+
+  /// The hash key for this storage is a pair of the integer and type params.
+  using KeyTy = std::pair<unsigned, Type>;
+
+  /// Define the comparison function for the key type.
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(nonZeroParam, integerType);
+  }
+
+  /// Define a hash function for the key type.
+  /// Note: This isn't necessary because std::pair, unsigned, and Type all have
+  /// hash functions already available.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.first, key.second);
+  }
+
+  /// Define a construction function for the key type.
+  /// Note: This isn't necessary because KeyTy can be directly constructed with
+  /// the given parameters.
+  static KeyTy getKey(unsigned nonZeroParam, Type integerType) {
+    return KeyTy(nonZeroParam, integerType);
+  }
+
+  /// Define a construction method for creating a new instance of this storage.
+  static ComplexTypeStorage *construct(TypeStorageAllocator &allocator,
+                                       const KeyTy &key) {
+    return new (allocator.allocate<ComplexTypeStorage>())
+        ComplexTypeStorage(key.first, key.second);
+  }
+
+  unsigned nonZeroParam;
+  Type integerType;
+};
+```
+
+##### Type class definition
+
+Now that the storage class has been created, the derived type class can be
+defined. This structure is similar to the
+[simple type](#simple-non-parametric-types), except that a bit more of the
+functionality of `Type::TypeBase` is put to use.
+
+```c++
+/// This class defines a parametric type. All derived types must inherit from
+/// the CRTP class 'Type::TypeBase'. It takes as template parameters the
+/// concrete type (ComplexType), the base class to use (Type), and the storage
+/// class (ComplexTypeStorage). 'Type::TypeBase' also provides several utility
+/// methods to simplify type construction and verification.
+class ComplexType
+    : public Type::TypeBase<ComplexType, Type, ComplexTypeStorage> {
+public:
+  /// Inherit some necessary constructors from 'TypeBase'.
+  using Base::Base;
+
+  /// This static method is used to support type inquiry through isa, cast,
+  /// and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == MyTypes::Complex; }
+
+  /// This method is used to get an instance of the 'ComplexType'. This method
+  /// asserts that all of the construction invariants were satisfied. To
+  /// gracefully handle failed construction, getChecked should be used instead.
+  static ComplexType get(MLIRContext *context, unsigned param, Type type) {
+    // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
+    // of this type. All parameters to the storage class are passed after the
+    // type kind.
+    return Base::get(context, MyTypes::Complex, param, type);
+  }
+
+  /// This method is used to get an instance of the 'ComplexType', defined at
+  /// the given location. If any of the construction invariants are invalid,
+  /// errors are emitted with the provided location and a null type is returned.
+  /// Note: This method is completely optional.
+  static ComplexType getChecked(MLIRContext *context, unsigned param, Type type,
+                                Location location) {
+    // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
+    // instance of this type. All parameters to the storage class are passed
+    // after the type kind.
+    return Base::getChecked(location, context, MyTypes::Complex, param, type);
+  }
+
+  /// This method is used to verify the construction invariants passed into the
+  /// 'get' and 'getChecked' methods. Note: This method is completely optional.
+  static LogicalResult verifyConstructionInvariants(
+      llvm::Optional<Location> loc, MLIRContext *context, unsigned param,
+      Type type) {
+    // Our type only allows non-zero parameters.
+    if (param == 0) {
+      if (loc)
+        context->emitError(*loc) << "non-zero parameter passed to 'ComplexType'";
+      return failure();
+    }
+    // Our type also expects an integer type.
+    if (!type.isa<IntegerType>()) {
+      if (loc)
+        context->emitError(*loc) << "non integer-type passed to 'ComplexType'";
+      return failure();
+    }
+    return success();
+  }
+
+  /// Return the parameter value.
+  unsigned getParameter() {
+    // 'getImpl' returns a pointer to our internal storage instance.
+    return getImpl()->nonZeroParam;
+  }
+
+  /// Return the integer parameter type.
+  IntegerType getParameterType() {
+    // 'getImpl' returns a pointer to our internal storage instance.
+    return getImpl()->integerType;
+  }
+};
+```
+
+### Registering types with a Dialect
+
+Once the dialect types have been defined, they must then be registered with a
+`Dialect`. This is done via a similar mechanism to
+[operations](LangRef.md#operations), `addTypes`.
+
+```c++
+struct MyDialect : public Dialect {
+  MyDialect(MLIRContext *context) : Dialect(/*name=*/"mydialect", context) {
+    /// Add these types to the dialect.
+    addTypes<SimpleType, ComplexType>();
+  }
+};
+```
+
+### Parsing and Printing
+
+As a final step after registration, a dialect must override the `printType` and
+`parseType` hooks. These enable native support for roundtripping the type in the
+textual IR.
+
+## Attributes
+
+As stated in the introduction, the process for defining dialect attributes is
+nearly identical to that of defining dialect types. The key difference is that
+the things named `*Type` are generally now named `*Attr`.
+
+* `Type::TypeBase` -> `Attribute::AttrBase`
+* `TypeStorageAllocator` -> `AttributeStorageAllocator`
+* `addTypes` -> `addAttributes`
+
+Aside from that, all of the interfaces for uniquing and storage construction are
+all the same.
diff --git a/mlir/docs/DeveloperGuide.md b/mlir/docs/DeveloperGuide.md
new file mode 100644
index 0000000000000000000000000000000000000000..745009959256b0913e929d225f438ae92017fcf5
--- /dev/null
+++ b/mlir/docs/DeveloperGuide.md
@@ -0,0 +1,107 @@
+# Developer Guide
+
+This document attempts to describe a few developer policies used in MLIR (such
+as coding standards used) as well as development approach (such as testing
+methods).
+
+## Style guide
+
+MLIR follows the [LLVM style](https://llvm.org/docs/CodingStandards.html) guide.
+We also adhere to the following (which deviate from or are not specified in the
+LLVM style guide):
+
+* Adopts [camelBack](https://llvm.org/docs/Proposals/VariableNames.html);
+* Except for IR units (Region, Block, and Operation), non-nullable output
+  arguments are passed by non-const reference in general.
+* IR constructs are not designed for [const correctness](UsageOfConst.md). +* Do *not* use recursive algorithms if the recursion can't be bounded + statically: that is avoid recursion if there is a possible IR input that can + trigger a stack overflow (for example traversing use-def chains in a + recursive way). At the moment, we tolerate it for the two following cases: + * The nesting of the IR: we use recursion when traversing nested regions. + * Type nesting: recursion may be used for the nesting of composite types. +* Follow the `git` conventions for writing a commit message, in particular the + first line is the "title", it should be followed by an empty line and an + optional description. This [post](https://chris.beams.io/posts/git-commit/) + give examples and more details. + +Please run clang-format on the files you modified with the `.clang-format` +configuration file available in the root directory. Check the clang-format +[documentation](https://clang.llvm.org/docs/ClangFormat.html) for more details +on integrating it with your development environment. In particular, if clang is +installed system-wide, running `git clang-format origin/master` will update the +files in the working directory with the relevant formatting changes; don't +forget to include those to the commit. + +## Pass name and other command line options + +To avoid collision between options provided by different dialects, the naming +convention is to prepend the dialect name to every dialect-specific passes and +options in general. Options that are specific to a pass should also be prefixed +with the pass name. For example, the affine dialect provides a loop tiling pass +that is registered on the command line as `-affine-tile`, and with a tile size +option that can be set with `-affine-tile-size`. + +We also avoid `cl::opt` to provide pass options in favor of the +[pass options](WritingAPass.md#instance-specific-pass-options) mechanism. 
This +allows for these options to be serialized in a pass pipeline description, as +well as passing different options to multiple instances of a pass in the same +pipeline. + +## Testing guidelines + +See here for the [testing guide](TestingGuide.md). + +## Guidelines on contributing a new dialect (or important components) + +To contribute a dialect (or a major component in MLIR), it is usual to write an +overview "RFC" (it can be just a few informal paragraphs) and send it to the +MLIR mailing list. When accepting a new component to MLIR, the community is also +accepting the burden of maintaining it. The following points should be +considered when evaluating whether a dialect is a good fit for the core MLIR +repository: + +* What is the overall goal of the dialect? What is the first implementation + milestone? +* How does it fit into the MLIR dialect ecosystem? + * Connection: how does it connect to the existing dialects in a + compilation pipeline(s)? + * Consolidation: is there already a dialect with a similar goal or + matching abstractions; if so, can it be improved instead of adding a new + one? + * Reuse: how does it generalize to similar but slightly different + use-cases? +* What is the community of users that it is serving? +* Who are the future contributors/maintainers beyond those who propose the + dialect? + +On a practical aspect, we will expect the code to follow the other sections of +this document, with an emphasis on the documentation alongside the source code. + +It is prefered to upstream your dialects/components in small incremental patches +that can be individually reviewed. That is, after the initial RFC has been +agreed on, we encourage dialects to be built progressively by faster iterations +in-tree; as long as it is clear they evolve towards their milestones and goals. + +We have seen the following broad categories of dialects: + +* Edge dialects that model a representation external to MLIR. 
Examples include + LLVM, SPIR-V dialects, TensorFlow, XLA/HLO, ... Such dialects may be a + better fit for the project that contains the original representation instead + of being added to the MLIR repository. In particular, because MLIR will not + take an external dependency on another project. +* Structured Abstraction dialects that generalize common features of several + other dialects or introduce a programming model. Generalization is sometimes + demonstrated by having several dialects lower to or originate from a new + dialect. While additional abstractions may be useful, they should be traded + off against the additional complexity of the dialect ecosystem. Examples of + abstraction dialects include the GPU and Loop dialects. +* Transformation dialects that serve as input/output for program + transformations. These dialects are commonly introduced to materialize + transformation pre- and post-conditions in the IR, while conditions can be + obtained through analysis or through operation semantics. Examples include + Affine and Linalg dialects. + +While it can be useful to frame the goals of a proposal, this categorization is +not exhaustive or absolute, and the community is open to discussing any new +dialect beyond this taxonomy. diff --git a/mlir/docs/Diagnostics.md b/mlir/docs/Diagnostics.md new file mode 100644 index 0000000000000000000000000000000000000000..69a30942c0039b041480750c6edd8e14b8a5138d --- /dev/null +++ b/mlir/docs/Diagnostics.md @@ -0,0 +1,402 @@ +# Introduction and Usage Guide to MLIR's Diagnostics Infrastructure + +[TOC] + +This document presents an introduction to using and interfacing with MLIR's +diagnostics infrastructure. + +See [MLIR specification](LangRef.md) for more information about MLIR, the +structure of the IR, operations, etc. + +## Source Locations + +Source location information is extremely important for any compiler, because it +provides a baseline for debuggability and error-reporting. 
MLIR provides several +different location types depending on the situational need. + +### CallSite Location + +``` +callsite-location ::= 'callsite' '(' location 'at' location ')' +``` + +An instance of this location allows for representing a directed stack of +location usages. This connects a location of a `callee` with the location of a +`caller`. + +### FileLineCol Location + +``` +filelinecol-location ::= string-literal ':' integer-literal ':' integer-literal +``` + +An instance of this location represents a tuple of file, line number, and column +number. This is similar to the type of location that you get from most source +languages. + +### Fused Location + +``` +fused-location ::= `fused` fusion-metadata? '[' location (location ',')* ']' +fusion-metadata ::= '<' attribute-value '>' +``` + +An instance of a `fused` location represents a grouping of several other source +locations, with optional metadata that describes the context of the fusion. +There are many places within a compiler in which several constructs may be fused +together, e.g. pattern rewriting, that normally result in partial or even total +loss of location information. With `fused` locations, this is a non-issue. + +### Name Location + +``` +name-location ::= string-literal ('(' location ')')? +``` + +An instance of this location allows for attaching a name to a child location. +This can be useful for representing the locations of variable, or node, +definitions. + +### Opaque Location + +An instance of this location essentially contains a pointer to some data +structure that is external to MLIR and an optional location that can be used if +the first one is not suitable. Since it contains an external structure, only the +optional location is used during serialization. + +### Unknown Location + +``` +unknown-location ::= `unknown` +``` + +Source location information is an extremely integral part of the MLIR +infrastructure.
As such, location information is always present in the IR, and +must explicitly be set to unknown. Thus an instance of the `unknown` location +represents an unspecified source location. + +## Diagnostic Engine + +The `DiagnosticEngine` acts as the main interface for diagnostics in MLIR. It +manages the registration of diagnostic handlers, as well as the core API for +diagnostic emission. Handlers generally take the form of +`LogicalResult(Diagnostic &)`. If the result is `success`, it signals that the +diagnostic has been fully processed and consumed. If `failure`, it signals that +the diagnostic should be propagated to any previously registered handlers. It +can be interfaced with via an `MLIRContext` instance. + +```c++ +DiagnosticEngine engine = ctx->getDiagEngine(); + +/// Handle the reported diagnostic. +// Return success to signal that the diagnostic has either been fully processed, +// or failure if the diagnostic should be propagated to the previous handlers. +DiagnosticEngine::HandlerID id = engine.registerHandler( + [](Diagnostic &diag) -> LogicalResult { + bool should_propagate_diagnostic = ...; + return failure(should_propagate_diagnostic); +}); + + +// We can also elide the return value completely, in which case the engine assumes +// that all diagnostics are consumed (i.e. a success() result). +DiagnosticEngine::HandlerID id = engine.registerHandler([](Diagnostic &diag) { + return; +}); + +// Unregister this handler when we are done. +engine.eraseHandler(id); +``` + +### Constructing a Diagnostic + +As stated above, the `DiagnosticEngine` holds the core API for diagnostic +emission. A new diagnostic can be emitted with the engine via `emit`. This +method returns an [InFlightDiagnostic](#inflight-diagnostic) that can be +modified further. + +```c++ +InFlightDiagnostic emit(Location loc, DiagnosticSeverity severity); +``` + +Using the `DiagnosticEngine`, though, is generally not the preferred way to emit +diagnostics in MLIR.
[`operation`](LangRef.md#operations) provides utility +methods for emitting diagnostics: + +```c++ +// `emit` methods available in the mlir namespace. +InFlightDiagnostic emitError/Remark/Warning(Location); + +// These methods use the location attached to the operation. +InFlightDiagnostic Operation::emitError/Remark/Warning(); + +// This method creates a diagnostic prefixed with "'op-name' op ". +InFlightDiagnostic Operation::emitOpError(); +``` + +## Diagnostic + +A `Diagnostic` in MLIR contains all of the necessary information for reporting a +message to the user. A `Diagnostic` essentially boils down to three main +components: + +* [Source Location](#source-locations) +* Severity Level + - Error, Note, Remark, Warning +* Diagnostic Arguments + - The diagnostic arguments are used when constructing the output message. + +### Appending arguments + +Once a diagnostic has been constructed, the user can start composing it. The +output message of a diagnostic is composed of a set of diagnostic arguments that +have been attached to it. New arguments can be attached to a diagnostic in a few +different ways: + +```c++ +// A few interesting things to use when composing a diagnostic. +Attribute fooAttr; +Type fooType; +SmallVector fooInts; + +// Diagnostics can be composed via the streaming operators. +op->emitError() << "Compose an interesting error: " << fooAttr << ", " << fooType + << ", (" << fooInts << ')'; + +// This could generate something like (FuncAttr:@foo, IntegerType:i32, {0,1,2}): +"Compose an interesting error: @foo, i32, (0, 1, 2)" +``` + +### Attaching notes + +Unlike many other compiler frameworks, notes in MLIR cannot be emitted directly. +They must be explicitly attached to another non-note diagnostic. When +emitting a diagnostic, notes can be directly attached via `attachNote`. When +attaching a note, if the user does not provide an explicit source location the +note will inherit the location of the parent diagnostic.
+ +```c++ +// Emit a note with an explicit source location. +op->emitError("...").attachNote(noteLoc) << "..."; + +// Emit a note that inherits the parent location. +op->emitError("...").attachNote() << "..."; +``` + +## InFlight Diagnostic + +Now that [Diagnostics](#diagnostic) have been explained, we introduce the +`InFlightDiagnostic`, which is an RAII wrapper around a diagnostic that is set to be +reported. This allows for modifying a diagnostic while it is still in flight. If +it is not reported directly by the user it will automatically be reported when +destroyed. + +```c++ +{ + InFlightDiagnostic diag = op->emitError() << "..."; +} // The diagnostic is automatically reported here. +``` + +## Diagnostic Configuration Options + +Several options are provided to help control and enhance the behavior of +diagnostics. These options are listed below: + +### Print Operation On Diagnostic + +Command Line Flag: `-mlir-print-op-on-diagnostic` + +When a diagnostic is emitted on an operation, via `Operation::emitError/...`, +the textual form of that operation is printed and attached as a note to the +diagnostic. This option is useful for understanding the current form of an +operation that may be invalid, especially when debugging verifier failures. An +example output is shown below: + +```shell +test.mlir:3:3: error: 'module_terminator' op expects parent op 'module' + "module_terminator"() : () -> () + ^ +test.mlir:3:3: note: see current operation: "module_terminator"() : () -> () + "module_terminator"() : () -> () + ^ +``` + +### Print StackTrace On Diagnostic + +Command Line Flag: `-mlir-print-stacktrace-on-diagnostic` + +When a diagnostic is emitted, attach the current stack trace as a note to the +diagnostic. This option is useful for understanding which part of the compiler +generated certain diagnostics.
An example output is shown below: + +```shell +test.mlir:3:3: error: 'module_terminator' op expects parent op 'module' + "module_terminator"() : () -> () + ^ +test.mlir:3:3: note: diagnostic emitted with trace: + #0 0x000055dd40543805 llvm::sys::PrintStackTrace(llvm::raw_ostream&) llvm/lib/Support/Unix/Signals.inc:553:11 + #1 0x000055dd3f8ac162 emitDiag(mlir::Location, mlir::DiagnosticSeverity, llvm::Twine const&) /lib/IR/Diagnostics.cpp:292:7 + #2 0x000055dd3f8abe8e mlir::emitError(mlir::Location, llvm::Twine const&) /lib/IR/Diagnostics.cpp:304:10 + #3 0x000055dd3f998e87 mlir::Operation::emitError(llvm::Twine const&) /lib/IR/Operation.cpp:324:29 + #4 0x000055dd3f99d21c mlir::Operation::emitOpError(llvm::Twine const&) /lib/IR/Operation.cpp:652:10 + #5 0x000055dd3f96b01c mlir::OpTrait::HasParent::Impl::verifyTrait(mlir::Operation*) /mlir/IR/OpDefinition.h:897:18 + #6 0x000055dd3f96ab38 mlir::Op::Impl, mlir::OpTrait::IsTerminator>::BaseVerifier::Impl, mlir::OpTrait::IsTerminator >::verifyTrait(mlir::Operation*) /mlir/IR/OpDefinition.h:1052:29 + # ... + "module_terminator"() : () -> () + ^ +``` + +## Common Diagnostic Handlers + +To interface with the diagnostics infrastructure, users will need to register a +diagnostic handler with the [`DiagnosticEngine`](#diagnostic-engine). +Recognizing that many users will want the same handler functionality, MLIR +provides several common diagnostic handlers for immediate use. + +### Scoped Diagnostic Handler + +This diagnostic handler is a simple RAII class that registers and unregisters a +given diagnostic handler. This class can either be used directly, or in +conjunction with a derived diagnostic handler. + +```c++ +// Construct the handler directly. +MLIRContext context; +ScopedDiagnosticHandler scopedHandler(&context, [](Diagnostic &diag) { + ... +}); + +// Use this handler in conjunction with another.
+class MyDerivedHandler : public ScopedDiagnosticHandler { + MyDerivedHandler(MLIRContext *ctx) : ScopedDiagnosticHandler(ctx) { + // Set the handler that should be RAII managed. + setHandler([&](Diagnostic diag) { + ... + }); + } +}; +``` + +### SourceMgr Diagnostic Handler + +This diagnostic handler is a wrapper around an llvm::SourceMgr instance. It +provides support for displaying diagnostic messages inline with a line of a +respective source file. This handler will also automatically load newly seen +source files into the SourceMgr when attempting to display the source line of a +diagnostic. Example usage of this handler can be seen in the `mlir-opt` tool. + +```shell +$ mlir-opt foo.mlir + +/tmp/test.mlir:6:24: error: expected non-function type +func @foo() -> (index, ind) { + ^ +``` + +To use this handler in your tool, add the following: + +```c++ +SourceMgr sourceMgr; +MLIRContext context; +SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); +``` + +### SourceMgr Diagnostic Verifier Handler + +This handler is a wrapper around a llvm::SourceMgr that is used to verify that +certain diagnostics have been emitted to the context. To use this handler, +annotate your source file with expected diagnostics in the form of: + +* `expected-(error|note|remark|warning) {{ message }}` + +A few examples are shown below: + +```mlir +// Expect an error on the same line. +func @bad_branch() { + br ^missing // expected-error {{reference to an undefined block}} +} + +// Expect an error on an adjacent line. +func @foo(%a : f32) { + // expected-error@+1 {{unknown comparison predicate "foo"}} + %result = cmpf "foo", %a, %a : f32 + return +} + +// Expect an error on the next line that does not contain a designator. +// expected-remark@below {{remark on function below}} +// expected-remark@below {{another remark on function below}} +func @bar(%a : f32) + +// Expect an error on the previous line that does not contain a designator. 
+func @baz(%a : f32) +// expected-remark@above {{remark on function above}} +// expected-remark@above {{another remark on function above}} + +``` + +The handler will report an error if any unexpected diagnostics were seen, or if +any expected diagnostics weren't. + +```shell +$ mlir-opt foo.mlir + +/tmp/test.mlir:6:24: error: unexpected error: expected non-function type +func @foo() -> (index, ind) { + ^ + +/tmp/test.mlir:15:4: error: expected remark "expected some remark" was not produced +// expected-remark {{expected some remark}} + ^~~~~~~~~~~~~~~~~~~~~~~~~~ +``` + +Similarly to the [SourceMgr Diagnostic Handler](#sourcemgr-diagnostic-handler), +this handler can be added to any tool via the following: + +```c++ +SourceMgr sourceMgr; +MLIRContext context; +SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); +``` + +### Parallel Diagnostic Handler + +MLIR is designed from the ground up to be multi-threaded. One important thing +to keep in mind when multi-threading is determinism. This means that the +behavior seen when operating on multiple threads is the same as when operating +on a single thread. For diagnostics, this means that the ordering of the +diagnostics is the same regardless of the number of threads being operated on. +The ParallelDiagnosticHandler is introduced to solve this problem. + +After creating a handler of this type, the only remaining step is to ensure that +each thread that will be emitting diagnostics to the handler sets a respective +'orderID'. The orderID corresponds to the order in which diagnostics would be +emitted when executing synchronously. For example, if we were processing a list +of operations [a, b, c] on a single thread, diagnostics emitted while processing +operation 'a' would be emitted before those for 'b' or 'c'. This corresponds 1-1 +with the 'orderID'. The thread that is processing 'a' should set the orderID to +'0'; the thread processing 'b' should set it to '1'; and so on and so forth.
+This provides a way for the handler to deterministically order the diagnostics +that it receives given the thread that it is receiving on. + +A simple example is shown below: + +```c++ +MLIRContext *context = ...; +ParallelDiagnosticHandler handler(context); + +// Process a list of operations in parallel. +std::vector opsToProcess = ...; +llvm::for_each_n(llvm::parallel::par, 0, opsToProcess.size(), + [&](size_t i) { + // Notify the handler that we are processing the i'th operation. + handler.setOrderIDForThread(i); + auto *op = opsToProcess[i]; + ... + + // Notify the handler that we are finished processing diagnostics on this + // thread. + handler.eraseOrderIDForThread(); +}); +``` diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md new file mode 100644 index 0000000000000000000000000000000000000000..e6b652f21913afb1c26923e564afc70ff4f5bd90 --- /dev/null +++ b/mlir/docs/DialectConversion.md @@ -0,0 +1,277 @@ +# Dialect Conversion + +This document describes a framework in MLIR in which to perform operation +conversions between, and within dialects. This framework allows for transforming +illegal operations to those supported by a provided conversion target, via a set +of pattern-based operation rewriting patterns. + +[TOC] + +To utilize the framework, a few things must be provided: + +* A [Conversion Target](#conversion-target) +* A set of [Rewrite Patterns](#rewrite-pattern-specification) +* A [Type Converter](#type-conversion) (Optional) + +## Modes of Conversion + +When applying a conversion to a set of operations, there are several conversion +modes that can be selected from: + +* Partial Conversion + + - A partial conversion will legalize as many operations to the target as + possible, but will allow pre-existing operations that were not + explicitly marked as `illegal` to remain unconverted. This allows for + partially lowering parts of the module in the presence of unknown + operations. 
+ - A partial conversion can be applied via `applyPartialConversion`. + +* Full Conversion + + - A full conversion is only successful if all operations are properly + legalized to the given conversion target. This ensures that only known + operations will exist after the conversion process. + - A full conversion can be applied via `applyFullConversion`. + +* Analysis Conversion + + - An analysis conversion will analyze which operations are legalizable to + the given conversion target if a conversion were to be applied. Note + that no rewrites, or transformations, are actually applied to the input + operations. + - An analysis conversion can be applied via `applyAnalysisConversion`. + +## Conversion Target + +The conversion target is the formal definition of what is considered to be legal +during the conversion process. The final operations generated by the conversion +framework must be marked as legal on the `ConversionTarget` for the rewrite to +be a success. Existing operations need not always be legal, though; see the +different conversion modes for why. Operations and dialects may be marked with +any of the provided legality actions below: + +* Legal + + - This action signals that every instance of a given operation is legal, + i.e. any combination of attributes, operands, types, etc. are valid. + +* Dynamic + + - This action signals that only some instances of a given operation are + legal. This allows for defining fine-tune constraints, e.g. saying that + `addi` is only legal when operating on 32-bit integers. + - If a specific handler is not provided when setting the action, the + target must override the `isDynamicallyLegal` hook provided by + `ConversionTarget`. + +* Illegal + + - This action signals that no instance of a given operation is legal. + Operations marked as `illegal` must always be converted for the + conversion to be successful. This action also allows for selectively + marking specific operations as illegal in an otherwise legal dialect. 
+ +An example conversion target is shown below: + +```c++ +struct MyTarget : public ConversionTarget { + MyTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + //-------------------------------------------------------------------------- + // Marking an operation as Legal: + + /// Mark all operations within the LLVM dialect are legal. + addLegalDialects(); + + /// Mark `std.constant` op is always legal on this target. + addLegalOps(); + + //-------------------------------------------------------------------------- + // Marking an operation as dynamically legal. + + /// Mark all operations within Affine dialect have dynamic legality + /// constraints. + addDynamicallyLegalDialects(); + + /// Mark `std.return` as dynamically legal. + addDynamicallyLegalOp(); + + /// Mark `std.return` as dynamically legal, but provide a specific legality + /// callback. + addDynamicallyLegalOp([](ReturnOp op) { ... }); + + //-------------------------------------------------------------------------- + // Marking an operation as illegal. + + /// All operations within the GPU dialect are illegal. + addIllegalDialect(); + + /// Mark `std.br` and `std.cond_br` as illegal. + addIllegalOp(); + } + + /// Implement the default legalization handler to handle operations marked as + /// dynamically legal that were not provided with an explicit handler. + bool isDynamicallyLegal(Operation *op) override { ... } +}; +``` + +### Recursive Legality + +In some cases, it may be desirable to mark entire regions of operations as +legal. This provides an additional granularity of context to the concept of +"legal". The `ConversionTarget` supports marking operations, that were +previously added as `Legal` or `Dynamic`, as `recursively` legal. Recursive +legality means that if an operation instance is legal, either statically or +dynamically, all of the operations nested within are also considered legal. 
An +operation can be marked via `markOpRecursivelyLegal<>`: + +```c++ +ConversionTarget &target = ...; + +/// The operation must first be marked as `Legal` or `Dynamic`. +target.addLegalOp(...); +target.addDynamicallyLegalOp(...); + +/// Mark the operation as always recursively legal. +target.markOpRecursivelyLegal(); +/// Mark optionally with a callback to allow selective marking. +target.markOpRecursivelyLegal([](Operation *op) { ... }); +/// Mark optionally with a callback to allow selective marking. +target.markOpRecursivelyLegal([](MyOp op) { ... }); +``` + +## Rewrite Pattern Specification + +After the conversion target has been defined, a set of legalization patterns +must be provided to transform illegal operations into legal ones. The patterns +supplied here, that do not [require type changes](#conversion-patterns), are the +same as those described in the +[quickstart rewrites guide](QuickstartRewrites.md#adding-patterns), but have a +few additional [restrictions](#restrictions). The patterns provided do not need +to generate operations that are directly legal on the target. The framework will +automatically build a graph of conversions to convert non-legal operations into +a set of legal ones. + +As an example, say you define a target that supports one operation: `foo.add`. +When providing the following patterns: [`bar.add` -> `baz.add`, `baz.add` -> +`foo.add`], the framework will automatically detect that it can legalize +`baz.add` -> `foo.add` even though a direct conversion does not exist. This +means that you don’t have to define a direct legalization pattern for `bar.add` +-> `foo.add`. + +### Restrictions + +The framework processes operations in topological order, trying to legalize them +individually. As such, patterns used in the conversion framework have a few +additional restrictions: + +1. If a pattern matches, it must erase or replace the op it matched on. + Operations can *not* be updated in place. +2. 
Match criteria should not be based on the IR outside of the op itself. The + preceding ops will already have been processed by the framework (although it + may not update uses), and the subsequent IR will not yet be processed. This + can create confusion if a pattern attempts to match against a sequence of + ops (e.g. rewrite A + B -> C). That sort of rewrite should be performed in a + separate pass. + +## Type Conversion + +It is sometimes necessary as part of a conversion to convert the set of types +being operated on. In these cases, a `TypeConverter` object may be defined that +details how types should be converted. The `TypeConverter` is used by patterns +and by the general conversion infrastructure to convert the signatures of blocks +and regions. + +### Type Converter + +As stated above, the `TypeConverter` contains several hooks for detailing how to +convert types. Several of these hooks are detailed below: + +```c++ +class TypeConverter { + public: + /// This hook allows for converting a type. This function should return + /// failure if no valid conversion exists, success otherwise. If the new set + /// of types is empty, the type is removed and any usages of the existing + /// value are expected to be removed during conversion. + virtual LogicalResult convertType(Type t, SmallVectorImpl &results); + + /// This hook simplifies defining 1-1 type conversions. This function returns + /// the type to convert to on success, and a null type on failure. + virtual Type convertType(Type t); + + /// This hook allows for materializing a conversion from a set of types into + /// one result type by generating a cast operation of some kind. The generated + /// operation should produce one result, of 'resultType', with the provided + /// 'inputs' as operands. This hook must be overridden when a type conversion + /// results in more than one type, or if a type conversion may persist after + /// the conversion has finished.
+ virtual Operation *materializeConversion(PatternRewriter &rewriter, + Type resultType, + ArrayRef inputs, + Location loc); +}; +``` + +### Conversion Patterns + +When type conversion comes into play, the general Rewrite Patterns can no longer +be used. This is due to the fact that the operands of the operation being +matched will not correspond with the operands of the correct type as determined +by `TypeConverter`. The operation rewrites on type boundaries must thus use a +special pattern, the `ConversionPattern`. This pattern provides, as an +additional argument to the `matchAndRewrite` and `rewrite` methods, the set of +remapped operands corresponding to the desired type. These patterns also utilize +a special `PatternRewriter`, `ConversionPatternRewriter`, that provides special +hooks for use with the conversion infrastructure. + +```c++ +struct MyConversionPattern : public ConversionPattern { + /// The `matchAndRewrite` hooks on ConversionPatterns take an additional + /// `operands` parameter, containing the remapped operands of the original + /// operation. + virtual PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const; +}; +``` + +These patterns have the same [restrictions](#restrictions) as the basic rewrite +patterns used in dialect conversion. + +### Region Signature Conversion + +From the perspective of type conversion, the entry block to a region is often +special. The types of the entry block arguments are often tied semantically to +details on the operation, e.g. FuncOp, AffineForOp, etc. Given this, the +conversion of the types for this block must be done explicitly via a conversion +pattern. To convert the signature of a region entry block, a custom hook on the +ConversionPatternRewriter must be invoked `applySignatureConversion`. 
A +signature conversion, `TypeConverter::SignatureConversion`, can be built +programmatically: + +```c++ +class SignatureConversion { +public: + /// Remap an input of the original signature with a new set of types. The + /// new types are appended to the new signature conversion. + void addInputs(unsigned origInputNo, ArrayRef types); + + /// Append new input types to the signature conversion, this should only be + /// used if the new types are not intended to remap an existing input. + void addInputs(ArrayRef types); + + /// Remap an input of the original signature with a range of types in the + /// new signature. + void remapInput(unsigned origInputNo, unsigned newInputNo, + unsigned newInputCount = 1); + + /// Remap an input of the original signature to another `replacement` + /// value. This drops the original argument. + void remapInput(unsigned origInputNo, Value replacement); +}; +``` + +The `TypeConverter` provides several default utilities for signature conversion: +`convertSignatureArg`/`convertBlockSignature`. diff --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md new file mode 100644 index 0000000000000000000000000000000000000000..c5dcf6a679027082da0618198b21734bbf8aeacf --- /dev/null +++ b/mlir/docs/Dialects/Affine.md @@ -0,0 +1,610 @@ +# Affine Dialect + +This dialect provides a powerful abstraction for affine operations and analyses. + +[TOC] + +## Polyhedral Structures + +MLIR uses techniques from polyhedral compilation to make dependence analysis and +loop transformations efficient and reliable. This section introduces some of the +core concepts that are used throughout the document. + +### Dimensions and Symbols + +Dimensions and symbols are the two kinds of identifiers that can appear in the +polyhedral structures, and are always of [`index`](../LangRef.md#index-type) +type. Dimensions are declared in parentheses and symbols are declared in square +brackets. + +Examples: + +```mlir +// A 2d to 3d affine mapping. 
+// d0/d1 are dimensions, s0 is a symbol +#affine_map2to3 = (d0, d1)[s0] -> (d0, d1 + s0, d1 - s0) +``` + +Dimensional identifiers correspond to the dimensions of the underlying structure +being represented (a map, set, or more concretely a loop nest or a tensor); for +example, a three-dimensional loop nest has three dimensional identifiers. Symbol +identifiers represent an unknown quantity that can be treated as constant for a +region of interest. + +Dimensions and symbols are bound to SSA values by various operations in MLIR and +use the same parenthesized vs square bracket list to distinguish the two. + +Syntax: + +``` +// Uses of SSA values that are passed to dimensional identifiers. +dim-use-list ::= `(` ssa-use-list? `)` + +// Uses of SSA values that are used to bind symbols. +symbol-use-list ::= `[` ssa-use-list? `]` + +// Most things that bind SSA values bind dimensions and symbols. +dim-and-symbol-use-list ::= dim-use-list symbol-use-list? +``` + +SSA values bound to dimensions and symbols must always have 'index' type. + +Example: + +```mlir +#affine_map2to3 = (d0, d1)[s0] -> (d0, d1 + s0, d1 - s0) +// Binds %N to the s0 symbol in affine_map2to3. +%x = alloc()[%N] : memref<40x50xf32, #affine_map2to3> +``` + +### Restrictions on Dimensions and Symbols + +The affine dialect imposes certain restrictions on dimension and symbolic +identifiers to enable powerful analysis and transformation. 
A symbolic +identifier can be bound to an SSA value that is either an argument to the +function, a value defined at the top level of that function (outside of all +loops and if operations), the result of a +[`constant` operation](Standard.md#constant-operation), or the result of an +[`affine.apply` operation](#affineapply-operation) that recursively takes as +arguments any symbolic identifiers, or the result of a [`dim` +operation](Standard.md#dim-operation) on either a memref that is a function +argument or a memref where the corresponding dimension is either static or a +dynamic one in turn bound to a symbolic identifier. Dimensions may be bound not +only to anything that a symbol is bound to, but also to induction variables of +enclosing [`affine.for` operations](#affinefor-operation), and the result of an +[`affine.apply` operation](#affineapply-operation) (which recursively may use +other dimensions and symbols). + +### Affine Expressions + +Syntax: + +``` +affine-expr ::= `(` affine-expr `)` + | affine-expr `+` affine-expr + | affine-expr `-` affine-expr + | `-`? integer-literal `*` affine-expr + | affine-expr `ceildiv` integer-literal + | affine-expr `floordiv` integer-literal + | affine-expr `mod` integer-literal + | `-`affine-expr + | bare-id + | `-`? integer-literal + +multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` +``` + +`ceildiv` is the ceiling function which maps the result of the division of its +first argument by its second argument to the smallest integer greater than or +equal to that result. `floordiv` is a function which maps the result of the +division of its first argument by its second argument to the largest integer +less than or equal to that result. `mod` is the modulo operation: since its +second argument is always positive, its results are always positive in our +usage. The `integer-literal` operand for ceildiv, floordiv, and mod is always +expected to be positive. 
`bare-id` is an identifier which must have type +[index](../LangRef.md#index-type). The precedence of operations in an affine +expression are ordered from highest to lowest in the order: (1) +parenthesization, (2) negation, (3) modulo, multiplication, floordiv, and +ceildiv, and (4) addition and subtraction. All of these operators associate from +left to right. + +A _multidimensional affine expression_ is a comma separated list of +one-dimensional affine expressions, with the entire list enclosed in +parentheses. + +**Context:** An affine function, informally, is a linear function plus a +constant. More formally, a function f defined on a vector $$\vec{v} \in +\mathbb{Z}^n$$ is a multidimensional affine function of $$\vec{v}$$ if +$$f(\vec{v})$$ can be expressed in the form $$M \vec{v} + \vec{c}$$ where $$M$$ +is a constant matrix from $$\mathbb{Z}^{m \times n}$$ and $$\vec{c}$$ is a +constant vector from $$\mathbb{Z}$$. $$m$$ is the dimensionality of such an +affine function. MLIR further extends the definition of an affine function to +allow 'floordiv', 'ceildiv', and 'mod' with respect to positive integer +constants. Such extensions to affine functions have often been referred to as +quasi-affine functions by the polyhedral compiler community. MLIR uses the term +'affine map' to refer to these multidimensional quasi-affine functions. As +examples, $$(i+j+1, j)$$, $$(i \mod 2, j+i)$$, $$(j, i/4, i \mod 4)$$, $$(2i+1, +j)$$ are two-dimensional affine functions of $$(i, j)$$, but $$(i \cdot j, +i^2)$$, $$(i \mod j, i/j)$$ are not affine functions of $$(i, j)$$. + +### Affine Maps + +Syntax: + +``` +affine-map-inline + ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr +``` + +The identifiers in the dimensions and symbols lists must be unique. These are +the only identifiers that may appear in 'multi-dim-affine-expr'. 
Affine maps +with one or more symbols in its specification are known as "symbolic affine +maps", and those with no symbols as "non-symbolic affine maps". + +**Context:** Affine maps are mathematical functions that transform a list of +dimension indices and symbols into a list of results, with affine expressions +combining the indices and symbols. Affine maps distinguish between +[indices and symbols](#dimensions-and-symbols) because indices are inputs to the +affine map when the map is called (through an operation such as +[affine.apply](#affineapply-operation)), whereas symbols are bound when +the map is established (e.g. when a memref is formed, establishing a +memory [layout map](../LangRef.md#layout-map)). + +Affine maps are used for various core structures in MLIR. The restrictions we +impose on their form allows powerful analysis and transformation, while keeping +the representation closed with respect to several operations of interest. + +#### Named affine mappings + +Syntax: + +``` +affine-map-id ::= `#` suffix-id + +// Definitions of affine maps are at the top of the file. +affine-map-def ::= affine-map-id `=` affine-map-inline +module-header-def ::= affine-map-def + +// Uses of affine maps may use the inline form or the named form. +affine-map ::= affine-map-id | affine-map-inline +``` + +Affine mappings may be defined inline at the point of use, or may be hoisted to +the top of the file and given a name with an affine map definition, and used by +name. + +Examples: + +```mlir +// Affine map out-of-line definition and usage example. +#affine_map42 = (d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2) + +// Use an affine mapping definition in an alloc operation, binding the +// SSA value %N to the symbol s0. +%a = alloc()[%N] : memref<4x4xf32, #affine_map42> + +// Same thing with an inline affine mapping definition. 
+%b = alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)> +``` + +### Semi-affine maps + +Semi-affine maps are extensions of affine maps to allow multiplication, +`floordiv`, `ceildiv`, and `mod` with respect to symbolic identifiers. +Semi-affine maps are thus a strict superset of affine maps. + +Syntax of semi-affine expressions: + +``` +semi-affine-expr ::= `(` semi-affine-expr `)` + | semi-affine-expr `+` semi-affine-expr + | semi-affine-expr `-` semi-affine-expr + | symbol-or-const `*` semi-affine-expr + | semi-affine-expr `ceildiv` symbol-or-const + | semi-affine-expr `floordiv` symbol-or-const + | semi-affine-expr `mod` symbol-or-const + | bare-id + | `-`? integer-literal + +symbol-or-const ::= `-`? integer-literal | symbol-id + +multi-dim-semi-affine-expr ::= `(` semi-affine-expr (`,` semi-affine-expr)* `)` +``` + +The precedence and associativity of operations in the syntax above is the same +as that for [affine expressions](#affine-expressions). + +Syntax of semi-affine maps: + +``` +semi-affine-map-inline + ::= dim-and-symbol-id-lists `->` multi-dim-semi-affine-expr +``` + +Semi-affine maps may be defined inline at the point of use, or may be hoisted to +the top of the file and given a name with a semi-affine map definition, and used +by name. + +``` +semi-affine-map-id ::= `#` suffix-id + +// Definitions of semi-affine maps are at the top of file. +semi-affine-map-def ::= semi-affine-map-id `=` semi-affine-map-inline +module-header-def ::= semi-affine-map-def + +// Uses of semi-affine maps may use the inline form or the named form. +semi-affine-map ::= semi-affine-map-id | semi-affine-map-inline +``` + +### Integer Sets + +An integer set is a conjunction of affine constraints on a list of identifiers. +The identifiers associated with the integer set are separated out into two +classes: the set's dimension identifiers, and the set's symbolic identifiers. +The set is viewed as being parametric on its symbolic identifiers. 
In the +syntax, the list of set's dimension identifiers are enclosed in parentheses +while its symbols are enclosed in square brackets. + +Syntax of affine constraints: + +``` +affine-constraint ::= affine-expr `>=` `0` + | affine-expr `==` `0` +affine-constraint-conjunction ::= affine-constraint (`,` affine-constraint)* +``` + +Integer sets may be defined inline at the point of use, or may be hoisted to the +top of the file and given a name with an integer set definition, and used by +name. + +``` +integer-set-id ::= `#` suffix-id + +integer-set-inline + ::= dim-and-symbol-id-lists `:` '(' affine-constraint-conjunction? ')' + +// Declarations of integer sets are at the top of the file. +integer-set-decl ::= integer-set-id `=` integer-set-inline + +// Uses of integer sets may use the inline form or the named form. +integer-set ::= integer-set-id | integer-set-inline +``` + +The dimensionality of an integer set is the number of identifiers appearing in +dimension list of the set. The affine-constraint non-terminals appearing in the +syntax above are only allowed to contain identifiers from dims and symbols. A +set with no constraints is a set that is unbounded along all of the set's +dimensions. + +Example: + +```mlir +// A example two-dimensional integer set with two symbols. +#set42 = (d0, d1)[s0, s1] + : (d0 >= 0, -d0 + s0 - 1 >= 0, d1 >= 0, -d1 + s1 - 1 >= 0) + +// Inside a Region +affine.if #set42(%i, %j)[%M, %N] { + ... +} +``` + +`d0` and `d1` correspond to dimensional identifiers of the set, while `s0` and +`s1` are symbol identifiers. + +## Operations + +#### 'affine.apply' operation + +Syntax: + +``` +operation ::= ssa-id `=` `affine.apply` affine-map dim-and-symbol-use-list +``` + +The `affine.apply` operation applies an +[affine mapping](#affine-expressions) to a list of SSA values, +yielding a single SSA value. 
The number of dimension and symbol arguments to +affine.apply must be equal to the respective number of dimensional and symbolic +inputs to the affine mapping; the `affine.apply` operation always returns one +value. The input operands and result must all have 'index' type. + +Example: + +```mlir +#map10 = (d0, d1) -> (d0 floordiv 8 + d1 floordiv 128) +... +%1 = affine.apply #map10 (%s, %t) + +// Inline example. +%2 = affine.apply (i)[s0] -> (i+s0) (%42)[%n] +``` + +#### 'affine.for' operation + +Syntax: + +``` +operation ::= `affine.for` ssa-id `=` lower-bound `to` upper-bound + (`step` integer-literal)? `{` op* `}` + +lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound +upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound +shorthand-bound ::= ssa-id | `-`? integer-literal +``` + +The `affine.for` operation represents an affine loop nest. It has one region +containing its body. This region must contain one block that terminates with +[`affine.terminator`](#affineterminator-operation). *Note:* when `affine.for` is +printed in custom format, the terminator is omitted. The block has one argument +of [`index`](../LangRef.md#index-type) type that represents the induction +variable of the loop. + +The `affine.for` operation executes its body a number of times iterating from a +lower bound to an upper bound by a stride. The stride, represented by `step`, is +a positive constant integer which defaults to "1" if not present. The lower and +upper bounds specify a half-open range: the range includes the lower bound but +does not include the upper bound. + +The lower and upper bounds of a `affine.for` operation are represented as an +application of an affine mapping to a list of SSA values passed to the map. The +[same restrictions](#restrictions-on-dimensions-and-symbols) hold for these SSA +values as for all bindings of SSA values to dimensions and symbols. 
+ +The affine mappings for the bounds may return multiple results, in which case +the `max`/`min` keywords are required (for the lower/upper bound respectively), +and the bound is the maximum/minimum of the returned values. There is no +semantic ambiguity, but MLIR syntax requires the use of these keywords to make +things more obvious to human readers. + +Many upper and lower bounds are simple, so MLIR accepts two custom form +syntaxes: the form that accepts a single 'ssa-id' (e.g. `%N`) is shorthand for +applying that SSA value to a function that maps a single symbol to itself, e.g., +`()[s]->(s)()[%N]`. The integer literal form (e.g. `-42`) is shorthand for a +nullary mapping function that returns the constant value (e.g. `()->(-42)()`). + +Example showing reverse iteration of the inner loop: + +```mlir +#map57 = (d0)[s0] -> (s0 - d0 - 1) + +func @simple_example(%A: memref, %B: memref) { + %N = dim %A, 0 : memref + affine.for %i = 0 to %N step 1 { + affine.for %j = 0 to %N { // implicitly steps by 1 + %0 = affine.apply #map57(%j)[%N] + %tmp = call @F1(%A, %i, %0) : (memref, index, index)->(f32) + call @F2(%tmp, %B, %i, %0) : (f32, memref, index, index)->() + } + } + return +} +``` + +#### 'affine.if' operation + +Syntax: + +``` +operation ::= `affine.if` if-op-cond `{` op* `}` (`else` `{` op* `}`)? +if-op-cond ::= integer-set dim-and-symbol-use-list +``` + +The `affine.if` operation restricts execution to a subset of the loop iteration +space defined by an integer set (a conjunction of affine constraints). A single +`affine.if` may end with an optional `else` clause. + +The condition of the `affine.if` is represented by an +[integer set](#integer-sets) (a conjunction of affine constraints), +and the SSA values bound to the dimensions and symbols in the integer set. The +[same restrictions](#restrictions-on-dimensions-and-symbols) hold for these SSA +values as for all bindings of SSA values to dimensions and symbols. 
+ +The `affine.if` operation contains two regions for the "then" and "else" +clauses. The latter may be empty (i.e. contain no blocks), meaning the absence +of the else clause. When non-empty, both regions must contain exactly one block +terminating with [`affine.terminator`](#affineterminator-operation). *Note:* +when `affine.if` is printed in custom format, the terminator is omitted. These +blocks must not have any arguments. + +Example: + +```mlir +#set = (d0, d1)[s0]: (d0 - 10 >= 0, s0 - d0 - 9 >= 0, + d1 - 10 >= 0, s0 - d1 - 9 >= 0) +func @reduced_domain_example(%A, %X, %N) : (memref<10xi32>, i32, i32) { + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { + %0 = affine.apply #map42(%j) + %tmp = call @S1(%X, %i, %0) + affine.if #set(%i, %j)[%N] { + %1 = affine.apply #map43(%i, %j) + call @S2(%tmp, %A, %i, %1) + } + } + } + return +} +``` + +#### 'affine.load' operation + +Syntax: + +``` +operation ::= ssa-id `=` `affine.load` ssa-use `[` multi-dim-affine-map-of-ssa-ids `]` `:` memref-type +``` + +The `affine.load` op reads an element from a memref, where the index for each +memref dimension is an affine expression of loop induction variables and +symbols. The output of 'affine.load' is a new value with the same type as the +elements of the memref. An affine expression of loop IVs and symbols must be +specified for each dimension of the memref. The keyword 'symbol' can be used to +indicate SSA identifiers which are symbolic. + +Example: + +```mlir + + Example 1: + + %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> + + Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. 
+
+  %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+    : memref<100x100xf32>
+
+```
+
+#### 'affine.store' operation
+
+Syntax:
+
+```
+operation ::= `affine.store` ssa-use, ssa-use `[` multi-dim-affine-map-of-ssa-ids `]` `:` memref-type
+```
+
+The `affine.store` op writes an element to a memref, where the index for each
+memref dimension is an affine expression of loop induction variables and
+symbols. The 'affine.store' op stores a new value which is the same type as the
+elements of the memref. An affine expression of loop IVs and symbols must be
+specified for each dimension of the memref. The keyword 'symbol' can be used to
+indicate SSA identifiers which are symbolic.
+
+Example:
+
+```mlir
+
+  Example 1:
+
+    affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+
+  Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
+
+    affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+      : memref<100x100xf32>
+
+```
+
+#### 'affine.dma_start' operation
+
+Syntax:
+
+```
+operation ::= `affine.dma_start` ssa-use `[` multi-dim-affine-map-of-ssa-ids `]`, `[` multi-dim-affine-map-of-ssa-ids `]`, `[` multi-dim-affine-map-of-ssa-ids `]`, ssa-use `:` memref-type
+```
+
+The `affine.dma_start` op starts a non-blocking DMA operation that transfers
+data from a source memref to a destination memref. The source and destination
+memref need not be of the same dimensionality, but need to have the same
+elemental type. The operands include the source and destination memrefs,
+each followed by its indices, the size of the data transfer in terms of the
+number of elements (of the elemental type of the memref), a tag memref with
+its indices, and optionally at the end, a stride and a
+number_of_elements_per_stride arguments. The tag location is used by an
+AffineDmaWaitOp to check for completion. The indices of the source memref,
+destination memref, and the tag memref have the same restrictions as any
+affine.load/store. 
In particular, index for each memref dimension must be an
+affine expression of loop induction variables and symbols.
+The optional stride arguments should be of 'index' type, and specify a
+stride for the slower memory space (memory space with a lower memory space
+id), transferring chunks of number_of_elements_per_stride every stride until
+%num_elements are transferred. Either both or no stride arguments should be
+specified. The value of 'num_elements' must be a multiple of
+'number_of_elements_per_stride'.
+
+
+Example:
+
+```mlir
+
+For example, a DmaStartOp operation that transfers 256 elements of a memref
+'%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory
+space 1 at indices [%k + 7, %l], would be specified as follows:
+
+  %num_elements = constant 256
+  %idx = constant 0 : index
+  %tag = alloc() : memref<1xi32, 4>
+  affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
+    %num_elements :
+      memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
+
+  If %stride and %num_elt_per_stride are specified, the DMA is expected to
+  transfer %num_elt_per_stride elements every %stride elements apart from
+  memory space 0 until %num_elements are transferred.
+
+  affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
+    %stride, %num_elt_per_stride : ...
+
+```
+
+#### 'affine.dma_wait' operation
+
+Syntax:
+
+```
+operation ::= `affine.dma_wait` ssa-use `[` multi-dim-affine-map-of-ssa-ids `]`, ssa-use `:` memref-type
+```
+
+The `affine.dma_wait` op blocks until the completion of a DMA operation
+associated with the tag element '%tag[%index]'. %tag is a memref, and %index
+has to be an index with the same restrictions as any load/store index.
+In particular, index for each memref dimension must be an affine expression of
+loop induction variables and symbols. 
%num_elements is the number of elements +associated with the DMA operation. For example: + +Example: + +```mlir + + affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : + memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> + ... + ... + affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> + +``` + +#### 'affine.min' operation + +Syntax: + +``` +operation ::= ssa-id `=` `affine.min` affine-map dim-and-symbol-use-list +``` + +The `affine.min` operation applies an +[affine mapping](#affine-expressions) to a list of SSA values, and returns the +minimum value of all result expressions. The number of dimension and symbol +arguments to affine.min must be equal to the respective number of dimensional +and symbolic inputs to the affine mapping; the `affine.min` operation always +returns one value. The input operands and result must all have 'index' type. + +Example: + +```mlir + +%0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0) (%arg0)[%arg1] + +``` + +#### `affine.terminator` operation + +Syntax: + +``` +operation ::= `"affine.terminator"() : () -> ()` +``` + +Affine terminator is a special terminator operation for blocks inside affine +loops ([`affine.for`](#affinefor-operation)) and branches +([`affine.if`](#affineif-operation)). It unconditionally transmits the control +flow to the successor of the operation enclosing the region. + +*Rationale*: bodies of affine operations are [blocks](../LangRef.md#blocks) that +must have terminators. Loops and branches represent structured control flow and +should not accept arbitrary branches as terminators. + +This operation does _not_ have a custom syntax. However, affine control +operations omit the terminator in their custom syntax for brevity. 
diff --git a/mlir/docs/Dialects/GPU.md b/mlir/docs/Dialects/GPU.md new file mode 100644 index 0000000000000000000000000000000000000000..7dcd8f6053c420add42a84147d9bbffb35699b91 --- /dev/null +++ b/mlir/docs/Dialects/GPU.md @@ -0,0 +1,132 @@ +# GPU Dialect + +Note: this dialect is more likely to change than others in the near future; use +with caution. + +This dialect provides middle-level abstractions for launching GPU kernels +following a programming model similar to that of CUDA or OpenCL. It provides +abstractions for kernel invocations (and may eventually provide those for device +management) that are not present at the lower level (e.g., as LLVM IR intrinsics +for GPUs). Its goal is to abstract away device- and driver-specific +manipulations to launch a GPU kernel and provide a simple path towards GPU +execution from MLIR. It may be targeted, for example, by DSLs using MLIR. The +dialect uses `gpu` as its canonical prefix. + +## Memory attribution + +Memory buffers are defined at the function level, either in "gpu.launch" or in +"gpu.func" ops. This encoding makes it clear where the memory belongs and makes +the lifetime of the memory visible. The memory is only accessible while the +kernel is launched/the function is currently invoked. The latter is more strict +than actual GPU implementations but using static memory at the function level is +just for convenience. It is also always possible to pass pointers to the +workgroup memory into other functions, provided they expect the correct memory +space. + +The buffers are considered live throughout the execution of the GPU function +body. The absence of memory attribution syntax means that the function does not +require special buffers. 
Rationale: although the underlying models declare +memory buffers at the module level, we chose to do it at the function level to +provide some structuring for the lifetime of those buffers; this avoids the +incentive to use the buffers for communicating between different kernels or +launches of the same kernel, which should be done through function arguments +instead; we chose not to use `alloca`-style approach that would require more +complex lifetime analysis following the principles of MLIR that promote +structure and representing analysis results in the IR. + +## Operations + +### `gpu.block_dim` + +Returns the number of threads in the thread block (aka the block size) along the +x, y, or z `dimension`. + +Example: + +```mlir + %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index) +``` + +### `gpu.block_id` + +Returns the block id, i.e. the index of the current block within the grid along +the x, y, or z `dimension`. + +Example: + +```mlir + %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index) +``` + +### `gpu.grid_dim` + +Returns the number of thread blocks in the grid along the x, y, or z +`dimension`. + +Example: + +```mlir + %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) +``` + +### `gpu.thread_id` + +Returns the thread id, i.e. the index of the current thread within the block +along the x, y, or z `dimension`. + +Example: + +```mlir + %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index) +``` + +### `gpu.yield` + +Is a special terminator operation for blocks inside regions in gpu ops. It +returns values to the immediately enclosing gpu op. + +Example: + +```mlir +gpu.yield %f0, %f1 : f32, f32 +``` + +### `gpu.all_reduce` + +The "all_reduce" op reduces the value of every work item across a local +workgroup. The result is equal for all work items of a workgroup. 
+ +For example, both + +```mlir +%1 = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32) +%2 = "gpu.all_reduce"(%0) ({ +^bb(%lhs : f32, %rhs : f32): + %sum = addf %lhs, %rhs : f32 + "gpu.yield"(%sum) : (f32) -> () +}) : (f32) -> (f32) +``` + +compute the sum of each work item's %0 value. The first version specifies the +accumulation as operation, whereas the second version specifies the accumulation +as code region. The accumulation operation must either be `add` or `mul`. + +Either none or all work items of a workgroup need to execute this op +in convergence. + +### `gpu.barrier` + +The "barrier" op synchronizes all work items of a workgroup. It is used +to coordinate communication between the work items of the workgroup. + +```mlir +gpu.barrier +``` + +waits until all work items in the workgroup have reached this point and all +memory accesses made by these work items prior to the op are visible to all work +items in the workgroup. Data hazards between work items accessing the same +memory can be avoided by synchronizing work items in-between these accesses. + +Either none or all work items of a workgroup need to execute this op +in convergence. diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md new file mode 100644 index 0000000000000000000000000000000000000000..00d0fa02fece6c4de8fd4a17c38b366ae280e7d3 --- /dev/null +++ b/mlir/docs/Dialects/LLVM.md @@ -0,0 +1,429 @@ +# LLVM IR Dialect + +This dialect wraps the LLVM IR types and instructions into MLIR types and +operations. It provides several additional operations that are necessary to +cover for the differences in the IR structure (e.g., MLIR does not have `phi` +operations and LLVM IR does not have a `constant` operation). + +In this document, we use "LLVM IR" to designate the +[intermediate representation of LLVM](https://llvm.org/docs/LangRef.html) and +"LLVM IR _dialect_" to refer to the MLIR dialect reflecting LLVM instructions +and types. 
+
+[TOC]
+
+## Context and Module Association
+
+The LLVM IR dialect object _contains_ an LLVM Context and an LLVM Module that it
+uses to define, print, parse and manage LLVM IR types. These objects can be
+obtained from the dialect object using `.getLLVMContext()` and
+`.getLLVMModule()`. All LLVM IR objects that interact with the LLVM IR dialect
+must exist in the dialect's context.
+
+## Types
+
+The LLVM IR dialect defines a single MLIR type, `LLVM::LLVMType`, that can wrap
+any existing LLVM IR type. Its syntax is as follows:
+
+```
+type ::= `!llvm<"` llvm-canonical-type `">`
+llvm-canonical-type ::= <canonical textual representation defined by LLVM>
+```
+
+For example, one can use primitive types `!llvm.i32`, pointer types
+`!llvm<"i8*">`, vector types `!llvm<"<4 x float>">` or structure types
+`!llvm<"{i32, float}">`. The parsing and printing of the canonical form is
+delegated to the LLVM assembly parser and printer.
+
+LLVM IR dialect types contain an `llvm::Type*` object that can be obtained by
+calling `.getUnderlyingType()` and used in LLVM API calls directly. These
+objects are allocated within the LLVM context associated with the LLVM IR
+dialect and may be linked to the properties of the associated LLVM module.
+
+LLVM IR dialect type can be constructed from any `llvm::Type*` that is
+associated with the LLVM context of the dialect. In this document, we use the
+term "wrapped LLVM IR type" to refer to the LLVM IR dialect type containing a
+specific LLVM IR type.
+
+## Operations
+
+All operations in the LLVM IR dialect have a custom form in MLIR. The mnemonic
+of an operation is that used in LLVM IR prefixed with "`llvm.`".
+
+### LLVM functions
+
+MLIR functions are defined by an operation that is not built into the IR itself.
+The LLVM IR dialect provides an `llvm.func` operation to define functions
+compatible with LLVM IR. These functions have wrapped LLVM IR function type but
+use MLIR syntax to express it. They are required to have exactly one result
+type. 
LLVM function operation is intended to capture additional properties of
+LLVM functions, such as linkage and calling convention, that may be modeled
+differently by the built-in MLIR function.
+
+```mlir
+// The type of @bar is !llvm<"i64 (i64)">
+llvm.func @bar(%arg0: !llvm.i64) -> !llvm.i64 {
+  llvm.return %arg0 : !llvm.i64
+}
+
+// The type of @foo is !llvm<"void (i64)">
+// !llvm.void type is omitted
+llvm.func @foo(%arg0: !llvm.i64) {
+  llvm.return
+}
+
+// A function with `internal` linkage.
+llvm.func internal @internal_func() {
+  llvm.return
+}
+
+```
+
+### LLVM IR operations
+
+The following operations are currently supported. The semantics of these
+operations corresponds to the semantics of the similarly-named LLVM IR
+instructions.
+
+#### Integer binary arithmetic operations
+
+Take two arguments of wrapped LLVM IR integer type, produce one value of the
+same type.
+
+- `add`
+- `sub`
+- `mul`
+- `udiv`
+- `sdiv`
+- `urem`
+- `srem`
+
+Examples:
+
+```mlir
+// Integer addition.
+%0 = llvm.add %a, %b : !llvm.i32
+
+// Unsigned integer division.
+%1 = llvm.udiv %a, %b : !llvm.i32
+```
+
+#### Floating point binary arithmetic operations
+
+Take two arguments of wrapped LLVM IR floating point type, produce one value of
+the same type.
+
+- `fadd`
+- `fsub`
+- `fmul`
+- `fdiv`
+- `frem`
+
+Examples:
+
+```mlir
+// Float addition.
+%0 = llvm.fadd %a, %b : !llvm.float
+
+// Float division.
+%1 = llvm.fdiv %a, %b : !llvm.float
+```
+
+#### Memory-related operations
+
+- `<r> = alloca <size> x <type>`
+- `<r> = getelementptr <address>[<index> (, <index>)+]`
+- `<r> = load <address>`
+- `store <value>, <address>`
+
+In these operations, `<size>` must be a value of wrapped LLVM IR integer type,
+`<address>` must be a value of wrapped LLVM IR pointer type, and `<value>` must
+be a value of wrapped LLVM IR type that corresponds to the pointer type of
+`<address>
`. + +The `index` operands are integer values whose semantics is identical to the +non-pointer arguments of LLVM IR's `getelementptr`. + +Examples: + +```mlir +// Allocate an array of 4 floats on stack +%c4 = llvm.mlir.constant(4) : !llvm.i64 +%0 = llvm.alloca %c4 x !llvm.float : (!llvm.i64) -> !llvm<"float*"> + +// Get the second element of the array (note 0-based indexing). +%c1 = llvm.mlir.constant(1) : !llvm.i64 +%1 = llvm.getelementptr %0[%c1] : (!llvm<"float*">, !llvm.i64) + -> !llvm<"float*"> + +// Store a constant into this element. +%cf = llvm.mlir.constant(42.0 : f32) : !llvm.float +llvm.store %cf, %1 : !llvm<"float*"> + +// Load the value from this element. +%3 = llvm.load %1 : !llvm<"float*"> +``` + +#### Operations on values of aggregate type. + +- ` = extractvalue [ (, )+]` +- ` = insertvalue , [ (, )+]` + +In these operations, `` must be a value of wrapped LLVM IR structure +type and `` must be a value that corresponds to one of the (nested) +structure element types. + +Note the use of integer literals to designate subscripts, which is made possible +by `extractvalue` and `insertvalue` must have constant subscripts. Internally, +they are modeled as array attributes. + +Examples: + +```mlir +// Get the value third element of the second element of a structure. +%0 = llvm.extractvalue %s[1, 2] : !llvm<"{i32, {i1, i8, i16}"> + +// Insert the value to the third element of the second element of a structure. +// Note that this returns a new structure-typed value. +%1 = llvm.insertvalue %0, %s[1, 2] : !llvm<"{i32, {i1, i8, i16}"> +``` + +#### Terminator operations. + +Branch operations: + +- `br [()]` +- `cond_br [(),` + `()]` + +In order to comply with MLIR design, branch operations in the LLVM IR dialect +pass arguments to basic blocks. Successors must be valid block MLIR identifiers +and operand lists for each of them must have the same types as the arguments of +the respective blocks. `` must be a wrapped LLVM IR `i1` type. 
+ +Since LLVM IR uses the name of the predecessor basic block to identify the +sources of a PHI node, it is invalid for two entries of the PHI node to indicate +different values coming from the same block. Therefore, `cond_br` in the LLVM IR +dialect disallows its successors to be the same block _if_ this block has +arguments. + +Examples: + +```mlir +// Branch without arguments. +^bb0: + llvm.br ^bb0 + +// Branch and pass arguments. +^bb1(%arg: !llvm.i32): + llvm.br ^bb1(%arg : !llvm.i32) + +// Conditionally branch and pass arguments to one of the blocks. +llvm.cond_br %cond, ^bb0, %bb1(%arg : !llvm.i32) + +// It's okay to use the same block without arguments, but probably useless. +llvm.cond_br %cond, ^bb0, ^bb0 + +// ERROR: Passing different arguments to the same block in a conditional branch. +llvm.cond_br %cond, ^bb1(%0 : !llvm.i32), ^bb1(%1 : !llvm.i32) + +``` + +Call operations: + +- ` = call()` +- `call()` + +In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect implements +this behavior by providing a variadic `call` operation for 0- and 1-result +functions. Even though MLIR supports multi-result functions, LLVM IR dialect +disallows them. + +The `call` instruction supports both direct and indirect calls. Direct calls +start with a function name (`@`-prefixed) and indirect calls start with an SSA +value (`%`-prefixed). The direct callee, if present, is stored as a function +attribute `callee`. The trailing type of the instruction is always the MLIR +function type, which may be different from the indirect callee that has the +wrapped LLVM IR function type. + +Examples: + +```mlir +// Direct call without arguments and with one result. +%0 = llvm.call @foo() : () -> (!llvm.float) + +// Direct call with arguments and without a result. +llvm.call @bar(%0) : (!llvm.float) -> () + +// Indirect call with an argument and without a result. +llvm.call %1(%0) : (!llvm.float) -> () +``` + +#### Miscellaneous operations. 
+ +Integer comparisons: `icmp "predicate" , `. The following predicate +values are supported: + +- `eq` - equality comparison; +- `ne` - inequality comparison; +- `slt` - signed less-than comparison +- `sle` - signed less-than-or-equal comparison +- `sgt` - signed greater-than comparison +- `sge` - signed greater-than-or-equal comparison +- `ult` - unsigned less-than comparison +- `ule` - unsigned less-than-or-equal comparison +- `ugt` - unsigned greater-than comparison +- `uge` - unsigned greater-than-or-equal comparison + +Bitwise reinterpretation: `bitcast `. + +Selection: `select , , `. + +### Auxiliary MLIR operations + +These operations do not have LLVM IR counterparts but are necessary to map LLVM +IR into MLIR. They should be prefixed with `llvm.mlir`. + +#### `llvm.mlir.addressof` + +Creates an SSA value containing a pointer to a global variable or constant +defined by `llvm.mlir.global`. The global value can be defined after its first +referenced. If the global value is a constant, storing into it is not allowed. + +Examples: + +```mlir +func @foo() { + // Get the address of a global. + %0 = llvm.mlir.addressof @const : !llvm<"i32*"> + + // Use it as a regular pointer. + %1 = llvm.load %0 : !llvm<"i32*"> +} + +// Define the global. +llvm.mlir.global @const(42 : i32) : !llvm.i32 +``` + +#### `llvm.mlir.constant` + +Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, all +constants must be created as SSA values before being used in other operations. +`llvm.mlir.constant` creates such values for scalars and vectors. It has a +mandatory `value` attribute, which may be an integer, floating point attribute; +dense or sparse attribute containing integers or floats. The type of the +attribute is one the corresponding MLIR standard types. It may be omitted for +`i64` and `f64` types that are implied. The operation produces a new SSA value +of the specified LLVM IR dialect type. 
The type of that value _must_ correspond +to the attribute type converted to LLVM IR. + +Examples: + +```mlir +// Integer constant, internal i32 is mandatory +%0 = llvm.mlir.constant(42 : i32) : !llvm.i32 + +// It's okay to omit i64. +%1 = llvm.mlir.constant(42) : !llvm.i64 + +// Floating point constant. +%2 = llvm.mlir.constant(42.0 : f32) : !llvm.float + +// Splat dense vector constant. +%3 = llvm.mlir.constant(dense<1.0> : vector<4xf32>) : !llvm<"<4 x float>"> +``` + +#### `llvm.mlir.global` + +Since MLIR allows for arbitrary operations to be present at the top level, +global variables are defined using the `llvm.mlir.global` operation. Both global +constants and variables can be defined, and the value may also be initialized in +both cases. + +There are two forms of initialization syntax. Simple constants that can be +represented as MLIR attributes can be given in-line: + +```mlir +llvm.mlir.global @variable(32.0 : f32) : !llvm.float +``` + +This initialization and type syntax is similar to `llvm.mlir.constant` and may +use two types: one for MLIR attribute and another for the LLVM value. These +types must be compatible. + +More complex constants that cannot be represented as MLIR attributes can be +given in an initializer region: + +```mlir +// This global is initialized with the equivalent of: +// i32* getelementptr (i32* @g2, i32 2) +llvm.mlir.global constant @int_gep() : !llvm<"i32*"> { + %0 = llvm.mlir.addressof @g2 : !llvm<"i32*"> + %1 = llvm.mlir.constant(2 : i32) : !llvm.i32 + %2 = llvm.getelementptr %0[%1] : (!llvm<"i32*">, !llvm.i32) -> !llvm<"i32*"> + // The initializer region must end with `llvm.return`. + llvm.return %2 : !llvm<"i32*"> +} +``` + +Only one of the initializer attribute or initializer region may be provided. + +`llvm.mlir.global` must appear at top-level of the enclosing module. It uses an +@-identifier for its value, which will be uniqued by the module with respect to +other @-identifiers in it. 
+ +Examples: + +```mlir +// Global values use @-identifiers. +llvm.mlir.global constant @cst(42 : i32) : !llvm.i32 + +// Non-constant values must also be initialized. +llvm.mlir.global @variable(32.0 : f32) : !llvm.float + +// Strings are expected to be of wrapped LLVM i8 array type and do not +// automatically include the trailing zero. +llvm.mlir.global @string("abc") : !llvm<"[3 x i8]"> + +// For strings globals, the trailing type may be omitted. +llvm.mlir.global constant @no_trailing_type("foo bar") + +// A complex initializer is constructed with an initializer region. +llvm.mlir.global constant @int_gep() : !llvm<"i32*"> { + %0 = llvm.mlir.addressof @g2 : !llvm<"i32*"> + %1 = llvm.mlir.constant(2 : i32) : !llvm.i32 + %2 = llvm.getelementptr %0[%1] : (!llvm<"i32*">, !llvm.i32) -> !llvm<"i32*"> + llvm.return %2 : !llvm<"i32*"> +} +``` + +#### `llvm.mlir.null` + +Unlike LLVM IR, MLIR does not have first-class null pointers. They must be +explicitly created as SSA values using `llvm.mlir.null`. This operation has +operands or attributes, and returns a null value of a wrapped LLVM IR pointer +type. + +Examples: + +```mlir +// Null pointer to i8 value. +%0 = llvm.mlir.null : !llvm<"i8*"> + +// Null pointer to a function with signature void() value. +%1 = llvm.mlir.null : !llvm<"void()*"> +``` + +#### `llvm.mlir.undef` + +Unlike LLVM IR, MLIR does not have first-class undefined values. Such values +must be created as SSA values using `llvm.mlir.undef`. This operation has no +operands or attributes. It creates an undefined value of the specified LLVM IR +dialect type wrapping an LLVM IR structure type. + +Example: + +```mlir +// Create a structure with a 32-bit integer followed by a float. 
+%0 = llvm.mlir.undef : !llvm<"{i32, float}"> +``` diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md new file mode 100644 index 0000000000000000000000000000000000000000..1ed5a2c2a2641ef96072b8eadb1351a880d71354 --- /dev/null +++ b/mlir/docs/Dialects/Linalg.md @@ -0,0 +1,8 @@ +# Linalg Dialect + +To generate the documentation: + +```sh +mlir-tblgen --gen-op-doc -I /path/to/mlir/include \ +/path/to/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td +``` diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md new file mode 100644 index 0000000000000000000000000000000000000000..1d72e5449d3e846fe9b6d691912d84eee313b229 --- /dev/null +++ b/mlir/docs/Dialects/SPIR-V.md @@ -0,0 +1,1039 @@ +# SPIR-V Dialect + +This document describes the design of the SPIR-V dialect in MLIR. It lists +various design choices we made for modeling different SPIR-V mechanisms, and +their rationale. + +This document also explains in a high-level manner how different components are +organized and implemented in the code and gives steps to follow for extending +them. + +This document assumes familiarity with SPIR-V. [SPIR-V][Spirv] is the Khronos +Group’s binary intermediate language for representing graphics shaders and +compute kernels. It is adopted by multiple Khronos Group’s APIs, including +Vulkan and OpenCL. It is fully defined in a +[human-readable specification][SpirvSpec]; the syntax of various SPIR-V +instructions are encoded in a [machine-readable grammar][SpirvGrammar]. + +## Design Guidelines + +SPIR-V is a binary intermediate language that serves dual purpose: on one side, +it is an intermediate language to represent graphics shaders and compute kernels +for high-level languages to target; on the other side, it defines a stable +binary format for hardware driver consumption. As a result, SPIR-V has design +principles pertain to not only intermediate language, but also binary format. +For example, regularity is one of the design goals of SPIR-V. 
All concepts are +represented as SPIR-V instructions, including declaring extensions and +capabilities, defining types and constants, defining functions, attaching +additional properties to computation results, etc. This way favors binary +encoding and decoding for driver consumption but not necessarily compiler +transformations. + +### Dialect design principles + +The main objective of the SPIR-V dialect is to be a proper intermediate +representation (IR) to facilitate compiler transformations. While we still aim +to support serializing to and deserializing from the binary format for various +good reasons, the binary format and its concerns play less a role in the design +of the SPIR-V dialect: when there is a trade-off to be made between favoring IR +and supporting binary format, we lean towards the former. + +On the IR aspect, the SPIR-V dialect aims to model SPIR-V at the same semantic +level. It is not intended to be a higher level or lower level abstraction than +the SPIR-V specification. Those abstractions are easily outside the domain of +SPIR-V and should be modeled with other proper dialects so they can be shared +among various compilation paths. Because of the dual purpose of SPIR-V, SPIR-V +dialect staying at the same semantic level as the SPIR-V specification also +means we can still have straightforward serailization and deserailization for +the majority of functionalities. + +To summarize, the SPIR-V dialect follows the following design principles: + +* Stay as the same semantic level as the SPIR-V specification by having + one-to-one mapping for most concepts and entities. +* Adopt SPIR-V specification's syntax if possible, but deviate intentionally + to utilize MLIR mechanisms if it results in better representation and + benefits transformation. +* Be straightforward to serialize into and deserialize from the SPIR-V binary + format. 
+ +SPIR-V is designed to be consumed by hardware drivers, so its representation is +quite clear, yet verbose for some cases. Allowing representational deviation +gives us the flexibility to reduce the verbosity by using MLIR mechanisms. + +### Dialect scopes + +SPIR-V supports multiple execution environments, specified by client APIs. +Notable adopters include Vulkan and OpenCL. It follows that the SPIR-V dialect +should support multiple execution environments if to be a proper proxy of SPIR-V +in MLIR systems. The SPIR-V dialect is designed with these considerations: it +has proper support for versions, extensions, and capabilities and is as +extensible as SPIR-V specification. + +## Conventions + +The SPIR-V dialect adopts the following conventions for IR: + +* The prefix for all SPIR-V types and operations are `spv.`. +* All instructions in an extended instruction set are further qualified with + the extended instruction set's prefix. For example, all operations in the + GLSL extended instruction set is has the prefix of `spv.GLSL.`. +* Ops that directly mirror instructions in the specification have `CamelCase` + names that are the same as the instruction opnames (without the `Op` + prefix). For example, `spv.FMul` is a direct mirror of `OpFMul` in the + specification. Such an op will be serialized into and deserialized from one + SPIR-V instruction. +* Ops with `snake_case` names are those that have different representation + from corresponding instructions (or concepts) in the specification. These + ops are mostly for defining the SPIR-V structure. For example, `spv.module` + and `spv.constant`. They may correspond to one or more instructions during + (de)serialization. +* Ops with `_snake_case` names are those that have no corresponding + instructions (or concepts) in the binary format. They are introduced to + satisfy MLIR structural requirements. For example, `spv._module_end` and + `spv._merge`. They maps to no instructions during (de)serialization. 
+ +(TODO: consider merging the last two cases and adopting `spv.mlir.` prefix for +them.) + +## Module + +A SPIR-V module is defined via the `spv.module` op, which has one region that +contains one block. Model-level instructions, including function definitions, +are all placed inside the block. Functions are defined using the builtin `func` +op. + +We choose to model a SPIR-V module with a dedicated `spv.module` op based on the +following considerations: + +* It maps cleanly to a SPIR-V module in the specification. +* We can enforce SPIR-V specific verification that is suitable to be performed + at the module-level. +* We can attach additional model-level attributes. +* We can control custom assembly form. + +The `spv.module` op's region cannot capture SSA values from outside, neither +implicitly nor explicitly. The `spv.module` op's region is closed as to what ops +can appear inside: apart from the builtin `func` op, it can only contain ops +from the SPIR-V dialect. The `spv.module` op's verifier enforces this rule. This +meaningfully guarantees that a `spv.module` can be the entry point and boundary +for serialization. + +### Module-level operations + +SPIR-V binary format defines the following [sections][SpirvLogicalLayout]: + +1. Capabilities required by the module. +1. Extensions required by the module. +1. Extended instructions sets required by the module. +1. Addressing and memory model specification. +1. Entry point specifications. +1. Execution mode declarations. +1. Debug instructions. +1. Annotation/decoration instructions. +1. Type, constant, global variables. +1. Function declarations. +1. Function definitions. + +Basically, a SPIR-V binary module contains multiple module-level instructions +followed by a list of functions. Those module-level instructions are essential +and they can generate result ids referenced by functions, notably, declaring +resource variables to interact with the execution environment. 
+ +Compared to the binary format, we adjust how these module-level SPIR-V +instructions are represented in the SPIR-V dialect: + +#### Use MLIR attributes for metadata + +* Requirements for capabilities, extensions, extended instruction sets, + addressing model, and memory model is conveyed using `spv.module` + attributes. This is considered better because these information are for the + execution environment. It's easier to probe them if on the module op itself. +* Annotations/decoration instructions are "folded" into the instructions they + decorate and represented as attributes on those ops. This eliminates + potential forward references of SSA values, improves IR readability, and + makes querying the annotations more direct. More discussions can be found in + the [`Decorations`](#decorations) section. + +#### Model types with MLIR custom types + +* Types are represented using MLIR standard types and SPIR-V dialect specific + types. There are no type declaration ops in the SPIR-V dialect. More + discussions can be found in the [Types](#types) section later. + +#### Unify and localize constants + +* Various normal constant instructions are represented by the same + `spv.constant` op. Those instructions are just for constants of different + types; using one op to represent them reduces IR verbosity and makes + transformations less tedious. +* Normal constants are not placed in `spv.module`'s region; they are localized + into functions. This is to make functions in the SPIR-V dialect to be + isolated and explicit capturing. Constants are cheap to duplicate given + attributes are uniqued in `MLIRContext`. + +#### Adopt symbol-based global variables and specialization constant + +* Global variables are defined with the `spv.globalVariable` op. They do not + generate SSA values. Instead they have symbols and should be referenced via + symbols. To use a global variables in a function block, `spv._address_of` is + needed to turn the symbol into a SSA value. 
+* Specialization constants are defined with the `spv.specConstant` op. Similar + to global variables, they do not generate SSA values and have symbols for + reference, too. `spv._reference_of` is needed to turn the symbol into a SSA + value for use in a function block. + +The above choices enables functions in the SPIR-V dialect to be isolated and +explicit capturing. + +#### Disallow implicit capturing in functions + +* In SPIR-V specification, functions support implicit capturing: they can + reference SSA values defined in modules. In the SPIR-V dialect functions are + defined with `func` op, which disallows implicit capturing. This is more + friendly to compiler analyses and transformations. More discussions can be + found in the [Function](#function) section later. + +### Model entry points and execution models as normal ops + +* A SPIR-V module can have multiple entry points. And these entry points refer + to the function and interface variables. It’s not suitable to model them as + `spv.module` op attributes. We can model them as normal ops of using symbol + references. +* Similarly for execution modes, which are coupled with entry points, we can + model them as normal ops in `spv.module`'s region. + +## Decorations + +Annotations/decorations provide additional information on result ids. In SPIR-V, +all instructions can generate result ids, including value-computing and +type-defining ones. + +For decorations on value result ids, we can just have a corresponding attribute +attached to the operation generating the SSA value. For example, for the +following SPIR-V: + +```spirv +OpDecorate %v1 RelaxedPrecision +OpDecorate %v2 NoContraction +... +%v1 = OpFMul %float %0 %0 +%v2 = OpFMul %float %1 %1 +``` + +We can represent them in the SPIR-V dialect as: + +```mlir +%v1 = "spv.FMul"(%0, %0) {RelaxedPrecision: unit} : (f32, f32) -> (f32) +%v2 = "spv.FMul"(%1, %1) {NoContraction: unit} : (f32, f32) -> (f32) +``` + +This approach benefits transformations. 
Essentially those decorations are just +additional properties of the result ids (and thus their defining instructions). +In SPIR-V binary format, they are just represented as instructions. Literally +following SPIR-V binary format means we need to through def-use chains to find +the decoration instructions and query information from them. + +For decorations on type result ids, notice that practically, only result ids +generated from composite types (e.g., `OpTypeArray`, `OpTypeStruct`) need to be +decorated for memory layouting purpose (e.g., `ArrayStride`, `Offset`, etc.); +scalar/vector types are required to be uniqued in SPIR-V. Therefore, we can just +encode them directly in the dialect-specific type. + +## Types + +Theoretically we can define all SPIR-V types using MLIR extensible type system, +but other than representational purity, it does not buy us more. Instead, we +need to maintain the code and invest in pretty printing them. So we prefer to +use builtin/standard types if possible. + +The SPIR-V dialect reuses standard integer, float, and vector types: + +Specification | Dialect +:----------------------------------: | :-------------------------------: +`OpTypeBool` | `i1` +`OpTypeInt ` | `i` +`OpTypeFloat ` | `f` +`OpTypeVector ` | `vector< x >` + +Similarly, `mlir::NoneType` can be used for SPIR-V `OpTypeVoid`; builtin +function types can be used for SPIR-V `OpTypeFunction` types. + +The SPIR-V dialect and defines the following dialect-specific types: + +``` +spirv-type ::= array-type + | image-type + | pointer-type + | runtime-array-type + | struct-type +``` + +### Array type + +This corresponds to SPIR-V [array type][ArrayType]. 
Its syntax is + +``` +element-type ::= integer-type + | floating-point-type + | vector-type + | spirv-type + +array-type ::= `!spv.array<` integer-literal `x` element-type `>` +``` + +For example, + +```mlir +!spv.array<4 x i32> +!spv.array<16 x vector<4 x f32>> +``` + +### Image type + +This corresponds to SPIR-V [image type][ImageType]. Its syntax is + +``` +dim ::= `1D` | `2D` | `3D` | `Cube` | + +depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` + +arrayed-info ::= `NonArrayed` | `Arrayed` + +sampling-info ::= `SingleSampled` | `MultiSampled` + +sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` + +format ::= `Unknown` | `Rgba32f` | + +image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,` + arrayed-info `,` sampling-info `,` + sampler-use-info `,` format `>` +``` + +For example, + +```mlir +!spv.image +!spv.image +``` + +### Pointer type + +This corresponds to SPIR-V [pointer type][PointerType]. Its syntax is + +``` +storage-class ::= `UniformConstant` + | `Uniform` + | `Workgroup` + | + +pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>` +``` + +For example, + +```mlir +!spv.ptr +!spv.ptr, Uniform> +``` + +### Runtime array type + +This corresponds to SPIR-V [runtime array type][RuntimeArrayType]. Its syntax is + +``` +runtime-array-type ::= `!spv.rtarray<` element-type `>` +``` + +For example, + +```mlir +!spv.rtarray +!spv.rtarray> +``` + +### Struct type + +This corresponds to SPIR-V [struct type][StructType]. Its syntax is + +``` +struct-member-decoration ::= integer-literal? spirv-decoration* +struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)? + (`, ` spirv-type (`[` struct-member-decoration `]`)? +``` + +For Example, + +```mlir +!spv.struct +!spv.struct +!spv.struct> +!spv.struct +``` + +## Function + +In SPIR-V, a function construct consists of multiple instructions involving +`OpFunction`, `OpFunctionParameter`, `OpLabel`, `OpFunctionEnd`. 
+ +```spirv +// int f(int v) { return v; } +%1 = OpTypeInt 32 0 +%2 = OpTypeFunction %1 %1 +%3 = OpFunction %1 %2 +%4 = OpFunctionParameter %1 +%5 = OpLabel +%6 = OpReturnValue %4 + OpFunctionEnd +``` + +This construct is very clear yet quite verbose. It is intended for driver +consumption. There is little benefit to literally replicate this construct in +the SPIR-V dialect. Instead, we reuse the builtin `func` op to express functions +more concisely: + +```mlir +func @f(%arg: i32) -> i32 { + "spv.ReturnValue"(%arg) : (i32) -> (i32) +} +``` + +A SPIR-V function can have at most one result. It cannot contain nested +functions or non-SPIR-V operations. `spv.module` verifies these requirements. + +A major difference between the SPIR-V dialect and the SPIR-V specification for +functions is that the former are isolated and require explicit capturing, while +the latter allow implicit capturing. In SPIR-V specification, functions can +refer to SSA values (generated by constants, global variables, etc.) defined in +modules. The SPIR-V dialect adjusted how constants and global variables are +modeled to enable isolated functions. Isolated functions are more friendly to +compiler analyses and transformations. This also enables the SPIR-V dialect to +better utilize core infrastructure: many functionalities in the core +infrastructure requires ops to be isolated, e.g., the +[greedy pattern rewriter][GreedyPatternRewriter] can only act on ops isolated +from above. + +(TODO: create a dedicated `spv.fn` op for SPIR-V functions.) + +## Operations + +In SPIR-V, instruction is a generalized concept; a SPIR-V module is just a +sequence of instructions. Declaring types, expressing computations, annotating +result ids, expressing control flows and others are all in the form of +instructions. + +We only discuss instructions expressing computations here, which can be +represented via SPIR-V dialect ops. 
Module-level instructions for declarations +and definitions are represented differently in the SPIR-V dialect as explained +earlier in the [Module-level operations](#module-level-operations) section. + +An instruction computes zero or one result from zero or more operands. The +result is a new result id. An operand can be a result id generated by a previous +instruction, an immediate value, or a case of an enum type. We can model result +id operands and results with MLIR SSA values; for immediate value and enum +cases, we can model them with MLIR attributes. + +For example, + +```spirv +%i32 = OpTypeInt 32 0 +%c42 = OpConstant %i32 42 +... +%3 = OpVariable %i32 Function 42 +%4 = OpIAdd %i32 %c42 %c42 +``` + +can be represented in the dialect as + +```mlir +%0 = "spv.constant"() { value = 42 : i32 } : () -> i32 +%1 = "spv.Variable"(%0) { storage_class = "Function" } : (i32) -> !spv.ptr +%2 = "spv.IAdd"(%0, %0) : (i32, i32) -> i32 +``` + +Operation documentation is written in each op's Op Definition Spec using +TableGen. A markdown version of the doc can be generated using `mlir-tblgen +-gen-doc`. + +### Ops from extended instruction sets + +Analogically extended instruction set is a mechanism to import SPIR-V +instructions within another namespace. [`GLSL.std.450`][GlslStd450] is an +extended instruction set that provides common mathematical routines that should +be supported. Instead of modeling `OpExtInstImport` as a separate op and use a +single op to model `OpExtInst` for all extended instructions, we model each +SPIR-V instruction in an extended instruction set as a separate op with the +proper name prefix. For example, for + +```spirv +%glsl = OpExtInstImport "GLSL.std.450" + +%f32 = OpTypeFloat 32 +%cst = OpConstant %f32 ... 
+ +%1 = OpExtInst %f32 %glsl 28 %cst +%2 = OpExtInst %f32 %glsl 31 %cst +``` + +we can have + +```mlir +%1 = "spv.GLSL.Log"(%cst) : (f32) -> (f32) +%2 = "spv.GLSL.Sqrt(%cst) : (f32) -> (f32) +``` + +## Control Flow + +SPIR-V binary format uses merge instructions (`OpSelectionMerge` and +`OpLoopMerge`) to declare structured control flow. They explicitly declare a +header block before the control flow diverges and a merge block where control +flow subsequently converges. These blocks delimit constructs that must nest, and +can only be entered and exited in structured ways. + +In the SPIR-V dialect, we use regions to mark the boundary of a structured +control flow construct. With this approach, it's easier to discover all blocks +belonging to a structured control flow construct. It is also more idiomatic to +MLIR system. + +We introduce a `spv.selection` and `spv.loop` op for structured selections and +loops, respectively. The merge targets are the next ops following them. Inside +their regions, a special terminator, `spv._merge` is introduced for branching to +the merge target. + +### Selection + +`spv.selection` defines a selection construct. It contains one region. The +region should contain at least two blocks: one selection header block and one +merge block. + +* The selection header block should be the first block. It should contain the + `spv.BranchConditional` or `spv.Switch` op. +* The merge block should be the last block. The merge block should only + contain a `spv._merge` op. Any block can branch to the merge block for early + exit. + +``` + +--------------+ + | header block | (may have multiple outgoing branches) + +--------------+ + / | \ + ... + + + +---------+ +---------+ +---------+ + | case #0 | | case #1 | | case #2 | ... (may have branches between each other) + +---------+ +---------+ +---------+ + + + ... 
+ \ | / + v + +-------------+ + | merge block | (may have multiple incoming branches) + +-------------+ +``` + +For example, for the given function + +```c++ +void loop(bool cond) { + int x = 0; + if (cond) { + x = 1; + } else { + x = 2; + } + // ... +} +``` + +It will be represented as + +```mlir +func @selection(%cond: i1) -> () { + %zero = spv.constant 0: i32 + %one = spv.constant 1: i32 + %two = spv.constant 2: i32 + %x = spv.Variable init(%zero) : !spv.ptr + + spv.selection { + spv.BranchConditional %cond, ^then, ^else + + ^then: + spv.Store "Function" %x, %one : i32 + spv.Branch ^merge + + ^else: + spv.Store "Function" %x, %two : i32 + spv.Branch ^merge + + ^merge: + spv._merge + } + + // ... +} + +``` + +### Loop + +`spv.loop` defines a loop construct. It contains one region. The region should +contain at least four blocks: one entry block, one loop header block, one loop +continue block, one merge block. + +* The entry block should be the first block and it should jump to the loop + header block, which is the second block. +* The merge block should be the last block. The merge block should only + contain a `spv._merge` op. Any block except the entry block can branch to + the merge block for early exit. +* The continue block should be the second to last block and it should have a + branch to the loop header block. +* The loop continue block should be the only block, except the entry block, + branching to the loop header block. + +``` + +-------------+ + | entry block | (one outgoing branch) + +-------------+ + | + v + +-------------+ (two incoming branches) + | loop header | <-----+ (may have one or two outgoing branches) + +-------------+ | + | + ... | + \ | / | + v | + +---------------+ | (may have multiple incoming branches) + | loop continue | -----+ (may have one or two outgoing branches) + +---------------+ + + ... 
+ \ | / + v + +-------------+ (may have multiple incoming branches) + | merge block | + +-------------+ +``` + +The reason to have another entry block instead of directly using the loop header +block as the entry block is to satisfy region's requirement: entry block of +region may not have predecessors. We have a merge block so that branch ops can +reference it as successors. The loop continue block here corresponds to +"continue construct" using SPIR-V spec's term; it does not mean the "continue +block" as defined in the SPIR-V spec, which is "a block containing a branch to +an OpLoopMerge instruction’s Continue Target." + +For example, for the given function + +```c++ +void loop(int count) { + for (int i = 0; i < count; ++i) { + // ... + } +} +``` + +It will be represented as + +```mlir +func @loop(%count : i32) -> () { + %zero = spv.constant 0: i32 + %one = spv.constant 1: i32 + %var = spv.Variable init(%zero) : !spv.ptr + + spv.loop { + spv.Branch ^header + + ^header: + %val0 = spv.Load "Function" %var : i32 + %cmp = spv.SLessThan %val0, %count : i32 + spv.BranchConditional %cmp, ^body, ^merge + + ^body: + // ... + spv.Branch ^continue + + ^continue: + %val1 = spv.Load "Function" %var : i32 + %add = spv.IAdd %val1, %one : i32 + spv.Store "Function" %var, %add : i32 + spv.Branch ^header + + ^merge: + spv._merge + } + return +} +``` + +### Block argument for Phi + +There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi` +instructions are modelled as block arguments in the SPIR-V dialect. (See the +[Rationale][Rationale] doc for "Block Arguments vs Phi nodes".) Each block +argument corresponds to one `OpPhi` instruction in the SPIR-V binary format. For +example, for the following SPIR-V function `foo`: + +```spirv + %foo = OpFunction %void None ... 
+%entry = OpLabel + %var = OpVariable %_ptr_Function_int Function + OpSelectionMerge %merge None + OpBranchConditional %true %true %false + %true = OpLabel + OpBranch %phi +%false = OpLabel + OpBranch %phi + %phi = OpLabel + %val = OpPhi %int %int_1 %false %int_0 %true + OpStore %var %val + OpReturn +%merge = OpLabel + OpReturn + OpFunctionEnd +``` + +It will be represented as: + +```mlir +func @foo() -> () { + %var = spv.Variable : !spv.ptr + + spv.selection { + %true = spv.constant true + spv.BranchConditional %true, ^true, ^false + + ^true: + %zero = spv.constant 0 : i32 + spv.Branch ^phi(%zero: i32) + + ^false: + %one = spv.constant 1 : i32 + spv.Branch ^phi(%one: i32) + + ^phi(%arg: i32): + spv.Store "Function" %var, %arg : i32 + spv.Return + + ^merge: + spv._merge + } + spv.Return +} +``` + +## Shader interface (ABI) + +SPIR-V itself is just expressing computation happening on GPU device. SPIR-V +programs themselves are not enough for running workloads on GPU; a companion +host application is needed to manage the resources referenced by SPIR-V programs +and dispatch the workload. For the Vulkan execution environment, the host +application will be written using Vulkan API. Unlike CUDA, the SPIR-V program +and the Vulkan application are typically authored with different front-end +languages, which isolates these two worlds. Yet they still need to match +_interfaces_: the variables declared in a SPIR-V program for referencing +resources need to match with the actual resources managed by the application +regarding their parameters. + +Still using Vulkan as an example execution environment, there are two primary +resource types in Vulkan: buffers and images. They are used to back various uses +that may differ regarding the classes of operations (load, store, atomic) to be +performed. These uses are differentiated via descriptor types. 
(For example, +uniform storage buffer descriptors can only support load operations while +storage buffer descriptors can support load, store, and atomic operations.) +Vulkan uses a binding model for resources. Resources are associated with +descriptors and descriptors are further grouped into sets. Each descriptor thus +has a set number and a binding number. Descriptors in the application +corresponds to variables in the SPIR-V program. Their parameters must match, +including but not limited to set and binding numbers. + +Apart from buffers and images, there is other data that is set up by Vulkan and +referenced inside the SPIR-V program, for example, push constants. They also +have parameters that require matching between the two worlds. + +The interface requirements are external information to the SPIR-V compilation +path in MLIR. Besides, each Vulkan application may want to handle resources +differently. To avoid duplication and to share common utilities, a SPIR-V shader +interface specification needs to be defined to provide the external requirements +to and guide the SPIR-V compilation path. + +### Shader interface attributes + +The SPIR-V dialect defines [a few attributes][MlirSpirvAbi] for specifying these +interfaces: + +* `spv.entry_point_abi` is a struct attribute that should be attached to the + entry function. It contains: + * `local_size` for specifying the local work group size for the dispatch. +* `spv.interface_var_abi` is a struct attribute that should be attached to + each operand and result of the entry function. It contains: + * `descriptor_set` for specifying the descriptor set number for the + corresponding resource variable. + * `binding` for specifying the binding number for the corresponding + resource variable. + * `storage_class` for specifying the storage class for the corresponding + resource variable. 
+
+The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for
+consuming these attributes and creating a SPIR-V module complying with the
+interface.
+
+## Serialization and deserialization
+
+Although the main objective of the SPIR-V dialect is to act as a proper IR for
+compiler transformations, being able to serialize to and deserialize from the
+binary format is still very valuable for many good reasons. Serialization
+enables the artifacts of SPIR-V compilation to be consumed by an execution
+environment; deserialization allows us to import SPIR-V binary modules and run
+transformations on them. So serialization and deserialization are supported from
+the very beginning of the development of the SPIR-V dialect.
+
+The serialization library provides two entry points, `mlir::spirv::serialize()`
+and `mlir::spirv::deserialize()`, for converting an MLIR SPIR-V module to binary
+format and back. The [Code organization](#code-organization) explains more about
+this.
+
+Given that the focus is transformations, which inevitably means changes to the
+binary module, serialization is not designed to be a general tool for
+investigating the SPIR-V binary module and does not guarantee roundtrip
+equivalence (at least for now). For the latter, please use the
+assembler/disassembler in the [SPIRV-Tools][SpirvTools] project.
+
+A few transformations are performed in the process of serialization because of
+the representational differences between SPIR-V dialect and binary format:
+
+* Attributes on `spv.module` are emitted as their corresponding SPIR-V
+  instructions.
+* Types are serialized into `OpType*` instructions in the SPIR-V binary module
+  section for types, constants, and global variables.
+* `spv.constant`s are unified and placed in the SPIR-V binary module section
+  for types, constants, and global variables.
+* Attributes on ops, if not part of the op's binary encoding, are emitted as + `OpDecorate*` instructions in the SPIR-V binary module section for + decorations. +* `spv.selection`s and `spv.loop`s are emitted as basic blocks with `Op*Merge` + instructions in the header block as required by the binary format. +* Block arguments are materialized as `OpPhi` instructions at the beginning of + the corresponding blocks. + +Similarly, a few transformations are performed during deserialization: + +* Instructions for execution environment requirements (extensions, + capabilities, extended instruction sets, etc.) will be placed as attributes + on `spv.module`. +* `OpType*` instructions will be converted into proper `mlir::Type`s. +* `OpConstant*` instructions are materialized as `spv.constant` at each use + site. +* `OpVariable` instructions will be converted to `spv.globalVariable` ops if + in module-level; otherwise they will be converted into `spv.Variable` ops. +* Every use of a module-level `OpVariable` instruction will materialize a + `spv._address_of` op to turn the symbol of the corresponding + `spv.globalVariable` into an SSA value. +* Every use of a `OpSpecConstant` instruction will materialize a + `spv._reference_of` op to turn the symbol of the corresponding + `spv.specConstant` into an SSA value. +* `OpPhi` instructions are converted to block arguments. +* Structured control flow are placed inside `spv.selection` and `spv.loop`. + +## Conversions + +(TODO: expand this section) + +## Code organization + +We aim to provide multiple libraries with clear dependencies for SPIR-V related +functionalities in MLIR so developers can just choose the needed components +without pulling in the whole world. + +### The dialect + +The code for the SPIR-V dialect resides in a few places: + +* Public headers are placed in [include/mlir/Dialect/SPIRV][MlirSpirvHeaders]. +* Libraries are placed in [lib/Dialect/SPIRV][MlirSpirvLibs]. 
+* IR tests are placed in [test/Dialect/SPIRV][MlirSpirvTests].
+* Unit tests are placed in [unittests/Dialect/SPIRV][MlirSpirvUnittests].
+
+The whole SPIR-V dialect is exposed via multiple headers for better
+organization:
+
+* [SPIRVDialect.h][MlirSpirvDialect] defines the SPIR-V dialect.
+* [SPIRVTypes.h][MlirSpirvTypes] defines all SPIR-V specific types.
+* [SPIRVOps.h][MlirSpirvOps] defines all SPIR-V operations.
+* [Serialization.h][MlirSpirvSerialization] defines the entry points for
+  serialization and deserialization.
+
+The dialect itself, including all types and ops, is in the `MLIRSPIRV` library.
+Serialization functionalities are in the `MLIRSPIRVSerialization` library.
+
+### Op definitions
+
+We use [Op Definition Spec][ODS] to define all SPIR-V ops. They are written in
+TableGen syntax and placed in various `*Ops.td` files in the header directory.
+Those `*Ops.td` files are organized according to the instruction categories used
+in the SPIR-V specification, for example, an op belonging to the "Atomics
+Instructions" section is put in the `SPIRVAtomicOps.td` file.
+
+`SPIRVOps.td` serves as the master op definition file that includes all files
+for specific categories.
+
+`SPIRVBase.td` defines common classes and utilities used by various op
+definitions. It contains the TableGen SPIR-V dialect definition, SPIR-V
+versions, known extensions, various SPIR-V enums, TableGen SPIR-V types, and
+base op classes, etc.
+
+Many of the contents in `SPIRVBase.td`, e.g., the opcodes and various enums, and
+all `*Ops.td` files can be automatically updated via a Python script, which
+queries the SPIR-V specification and grammar. This greatly reduces the burden of
+supporting new ops and keeping updated with the SPIR-V spec. More details on
+this automated development can be found in the
+[Automated development flow](#automated-development-flow) section.
+
+### Dialect conversions
+
+The code for conversions from other dialects to the SPIR-V dialect also resides
+in a few places:
+
+* From GPU dialect: headers are at
+  [include/mlir/Conversion/GPUToSPIRV][MlirGpuToSpirvHeaders]; libraries are
+  at [lib/Conversion/GPUToSPIRV][MlirGpuToSpirvLibs].
+* From standard dialect: headers are at
+  [include/mlir/Conversion/StandardToSPIRV][MlirStdToSpirvHeaders]; libraries
+  are at [lib/Conversion/StandardToSPIRV][MlirStdToSpirvLibs].
+
+These dialect to dialect conversions have their dedicated libraries,
+`MLIRGPUToSPIRVTransforms` and `MLIRStandardToSPIRVTransforms`, respectively.
+
+There are also common utilities when targeting SPIR-V from any dialect:
+
+* [include/mlir/Dialect/SPIRV/Passes.h][MlirSpirvPasses] contains SPIR-V
+  specific analyses and transformations.
+* [include/mlir/Dialect/SPIRV/SPIRVLowering.h][MlirSpirvLowering] contains
+  type converters and other utility functions.
+
+These common utilities are implemented in the `MLIRSPIRVTransforms` library.
+
+## Contribution
+
+All kinds of contributions are highly appreciated! :) We have GitHub issues for
+tracking the [dialect][GitHubDialectTracking] and
+[lowering][GitHubLoweringTracking] development. You can find todo tasks there.
+The [Code organization](#code-organization) section gives an overview of how
+SPIR-V related functionalities are implemented in MLIR. This section gives more
+concrete steps on how to contribute.
+
+### Automated development flow
+
+One of the goals of SPIR-V dialect development is to leverage both the SPIR-V
+[human-readable specification][SpirvSpec] and
+[machine-readable grammar][SpirvGrammar] to auto-generate as much contents as
+possible. Specifically, the following tasks can be automated (partially or
+fully):
+
+* Adding support for a new operation.
+* Adding support for a new SPIR-V enum.
+* Serialization and deserialization of a new operation.
+
+We achieve this using the Python script
+[`gen_spirv_dialect.py`][GenSpirvUtilsPy]. It fetches the human-readable
+specification and machine-readable grammar directly from the Internet and
+updates various SPIR-V `*.td` files in place. The script gives us an automated
+flow for adding support for new ops or enums.
+
+Afterwards, we have SPIR-V specific `mlir-tblgen` backends for reading the Op
+Definition Spec and generating various components, including (de)serialization
+logic for ops. Together with standard `mlir-tblgen` backends, we auto-generate
+all op classes, enum classes, etc.
+
+In the following subsections, we list the detailed steps to follow for common
+tasks.
+
+### Add a new op
+
+To add a new op, invoke the `define_inst.sh` script wrapper in utils/spirv.
+`define_inst.sh` requires a few parameters:
+
+```sh
+./define_inst.sh <filename> <base-class-name> <opname>
+```
+
+For example, to define the op for `OpIAdd`, invoke
+
+```sh
+./define_inst.sh SPIRVArithmeticOps.td ArithmeticBinaryOp OpIAdd
+```
+
+where `SPIRVArithmeticOps.td` is the filename for hosting the new op and
+`ArithmeticBinaryOp` is the direct base class the newly defined op will derive
+from.
+
+Similarly, to define the op for `OpAtomicAnd`,
+
+```sh
+./define_inst.sh SPIRVAtomicOps.td AtomicUpdateWithValueOp OpAtomicAnd
+```
+
+Note that the generated SPIR-V op definition is just a best-effort template; it
+is still expected to be updated to have more accurate traits, arguments, and
+results.
+
+The generated op will automatically gain the logic for (de)serialization.
+However, tests still need to be coupled with the change to make sure there are
+no surprises. Serialization tests live in test/Dialect/SPIRV/Serialization.
+
+### Add a new enum
+
+To add a new enum, invoke the `define_enum.sh` script wrapper in utils/spirv.
+`define_enum.sh` expects the following parameters:
+
+```sh
+./define_enum.sh <enum-class-name>
+```
+
+For example, to add the definition for SPIR-V storage class into
+`SPIRVBase.td`:
+
+```sh
+./define_enum.sh StorageClass
+```
+
+### Add a new conversion
+
+(TODO: add details for this section)
+
+[Spirv]: https://www.khronos.org/registry/spir-v/
+[SpirvSpec]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html
+[SpirvLogicalLayout]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
+[SpirvGrammar]: https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json
+[GlslStd450]: https://www.khronos.org/registry/spir-v/specs/1.0/GLSL.std.450.html
+[ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray
+[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
+[PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer
+[RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
+[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
+[SpirvTools]: https://github.com/KhronosGroup/SPIRV-Tools
+[Rationale]: https://github.com/tensorflow/mlir/blob/master/g3doc/Rationale.md#block-arguments-vs-phi-nodes
+[ODS]: https://github.com/tensorflow/mlir/blob/master/g3doc/OpDefinitions.md
+[GreedyPatternRewriter]: https://github.com/tensorflow/mlir/blob/master/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+[MlirSpirvHeaders]: https://github.com/tensorflow/mlir/tree/master/include/mlir/Dialect/SPIRV
+[MlirSpirvLibs]: https://github.com/tensorflow/mlir/tree/master/lib/Dialect/SPIRV
+[MlirSpirvTests]: https://github.com/tensorflow/mlir/tree/master/test/Dialect/SPIRV
+[MlirSpirvUnittests]: https://github.com/tensorflow/mlir/tree/master/unittests/Dialect/SPIRV
+[MlirGpuToSpirvHeaders]: 
https://github.com/tensorflow/mlir/tree/master/include/mlir/Conversion/GPUToSPIRV +[MlirGpuToSpirvLibs]: https://github.com/tensorflow/mlir/tree/master/lib/Conversion/GPUToSPIRV +[MlirStdToSpirvHeaders]: https://github.com/tensorflow/mlir/tree/master/include/mlir/Conversion/StandardToSPIRV +[MlirStdToSpirvLibs]: https://github.com/tensorflow/mlir/tree/master/lib/Conversion/StandardToSPIRV +[MlirSpirvDialect]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVDialect.h +[MlirSpirvTypes]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVTypes.h +[MlirSpirvOps]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVOps.h +[MlirSpirvSerialization]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/Serialization.h +[MlirSpirvBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVBase.td +[MlirSpirvPasses]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/Passes.h +[MlirSpirvLowering]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVLowering.h +[MlirSpirvAbi]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVLowering.td +[GitHubDialectTracking]: https://github.com/tensorflow/mlir/issues/302 +[GitHubLoweringTracking]: https://github.com/tensorflow/mlir/issues/303 +[GenSpirvUtilsPy]: https://github.com/tensorflow/mlir/blob/master/utils/spirv/gen_spirv_dialect.py diff --git a/mlir/docs/Dialects/Standard.md b/mlir/docs/Dialects/Standard.md new file mode 100644 index 0000000000000000000000000000000000000000..f84a2c94e921ed463b5730053bb7b89a89bba012 --- /dev/null +++ b/mlir/docs/Dialects/Standard.md @@ -0,0 +1,1146 @@ +# Standard Dialect + +This dialect provides documentation for operations within the Standard dialect. 
+ +Note: This dialect is a collection of operations for several different concepts, +and should be split into multiple more-focused dialects accordingly. + +[TOC] + +TODO: shape, which returns a 1D tensor, and can take an unknown rank tensor as +input. + +TODO: rank, which returns an index. + +## Terminator operations + +Terminator operations are required at the end of each block. They may contain a +list of successors, i.e. other blocks to which the control flow will proceed. + +### 'br' terminator operation + +Syntax: + +``` +operation ::= `br` successor +successor ::= bb-id branch-use-list? +branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` +``` + +The `br` terminator operation represents an unconditional jump to a target +block. The count and types of operands to the branch must align with the +arguments in the target block. + +The MLIR branch operation is not allowed to target the entry block for a region. + +### 'cond_br' terminator operation + +Syntax: + +``` +operation ::= `cond_br` ssa-use `,` successor `,` successor +``` + +The `cond_br` terminator operation represents a conditional branch on a boolean +(1-bit integer) value. If the bit is set, then the first destination is jumped +to; if it is false, the second destination is chosen. The count and types of +operands must align with the arguments in the corresponding target blocks. + +The MLIR conditional branch operation is not allowed to target the entry block +for a region. The two destinations of the conditional branch operation are +allowed to be the same. 
+ +The following example illustrates a function with a conditional branch operation +that targets the same block: + +```mlir +func @select(i32, i32, i1) -> i32 { +^bb0(%a : i32, %b :i32, %flag : i1) : + // Both targets are the same, operands differ + cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32) + +^bb1(%x : i32) : + return %x : i32 +} +``` + +### 'return' terminator operation + +Syntax: + +``` +operation ::= `return` (ssa-use-list `:` type-list-no-parens)? +``` + +The `return` terminator operation represents the completion of a function, and +produces the result values. The count and types of the operands must match the +result types of the enclosing function. It is legal for multiple blocks in a +single function to return. + +## Core Operations + +### 'call' operation + +Syntax: + +``` +operation ::= + (ssa-id `=`)? `call` symbol-ref-id `(` ssa-use-list? `)` `:` function-type +``` + +The `call` operation represents a direct call to a function. The operands and +result types of the call must match the specified function type. The callee is +encoded as a function attribute named "callee". + +Example: + +```mlir +// Calling the function my_add. +%31 = call @my_add(%0, %1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> +``` + +### 'call_indirect' operation + +Syntax: + +``` +operation ::= `call_indirect` ssa-use `(` ssa-use-list? `)` `:` function-type +``` + +The `call_indirect` operation represents an indirect call to a value of function +type. Functions are first class types in MLIR, and may be passed as arguments +and merged together with block arguments. The operands and result types of the +call must match the specified function type. + +Function values can be created with the +[`constant` operation](#constant-operation). 
+ +Example: + +```mlir +%31 = call_indirect %15(%0, %1) + : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> +``` + +### 'dim' operation + +Syntax: + +``` +operation ::= ssa-id `=` `dim` ssa-id `,` integer-literal `:` type +``` + +The `dim` operation takes a memref or tensor operand and a dimension index, and +returns an [`index`](../LangRef.md#index-type) that is the size of that +dimension. + +The `dim` operation is represented with a single integer attribute named +`index`, and the type specifies the type of the memref or tensor operand. + +Examples: + +```mlir +// Always returns 4, can be constant folded: +%x = dim %A, 0 : tensor<4 x ? x f32> + +// Returns the dynamic dimension of %A. +%y = dim %A, 1 : tensor<4 x ? x f32> + +// Equivalent generic form: +%x = "std.dim"(%A) {index = 0 : i64} : (tensor<4 x ? x f32>) -> index +%y = "std.dim"(%A) {index = 1 : i64} : (tensor<4 x ? x f32>) -> index +``` + +## Memory Operations + +### 'alloc' operation + +Syntax: + +``` +operation ::= ssa-id `=` `alloc` dim-and-symbol-use-list `:` memref-type +``` + +Allocates a new memref of specified type. Values required for dynamic dimension +sizes are passed as arguments in parentheses (in the same order in which they +appear in the shape signature of the memref) while the symbols required by the +layout map are passed in the square brackets in lexicographical order. If no +layout maps are specified in the memref, then an identity mapping is used. + +The buffer referenced by a memref type is created by the `alloc` operation, and +destroyed by the `dealloc` operation. + +Example: + +```mlir +// Allocating memref for a fully static shape. +%A = alloc() : memref<1024x64xf32, #layout_map0, memspace0> + +// %M, %N, %x, %y are SSA values of integer type. M and N are bound to the +// two unknown dimensions of the type and x/y are bound to symbols in +// #layout_map1. 
+%B = alloc(%M, %N)[%x, %y] : memref<?x?xf32, #layout_map1, memspace1>
+```
+
+### 'alloc_static' operation
+
+Syntax:
+
+```
+operation ::=
+    ssa-id `=` `alloc_static` `(` integer-literal `)` :  memref-type
+```
+
+Allocates a new memref of specified type with a fixed base pointer location in
+memory. 'alloc_static' does not support types that have dynamic shapes or that
+require dynamic symbols in their layout function (use the
+[`alloc` operation](#alloc-operation) in those cases).
+
+Example:
+
+```mlir
+%A = alloc_static(0x1232a00) : memref<1024 x 64 x f32, #layout_map0, memspace0>
+```
+
+The `alloc_static` operation is used to represent code after buffer allocation
+has been performed.
+
+### 'dealloc' operation
+
+Syntax:
+
+```
+operation ::= `dealloc` ssa-use `:` memref-type
+```
+
+Delineates the end of the lifetime of the memory corresponding to a memref
+allocation. It is paired with an [`alloc`](#alloc-operation) or
+[`alloc_static`](#alloc-static-operation) operation.
+
+Example:
+
+```mlir
+dealloc %A : memref<128 x f32, #layout, memspace0>
+```
+
+### 'dma_start' operation
+
+Syntax:
+
+```
+operation ::= `dma_start` ssa-use`[`ssa-use-list`]` `,`
+              ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
+              ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)?
+              `:` memref-type `,` memref-type `,` memref-type
+```
+
+Starts a non-blocking DMA operation that transfers data from a source memref to
+a destination memref. The operands include the source and destination memref's
+each followed by its indices, size of the data transfer in terms of the number
+of elements (of the elemental type of the memref), a tag memref with its
+indices, and optionally two additional arguments corresponding to the stride (in
+terms of number of elements) and the number of elements to transfer per stride.
+The tag location is used by a dma_wait operation to check for completion.
The
+indices of the source memref, destination memref, and the tag memref have the
+same restrictions as any load/store operation in an affine context (whenever DMA
+operations appear in an affine context). See
+[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
+in affine contexts. This allows powerful static analysis and transformations in
+the presence of such DMAs including rescheduling, pipelining / overlap with
+computation, and checking for matching start/end operations. The source and
+destination memref need not be of the same dimensionality, but need to have the
+same elemental type.
+
+For example, a `dma_start` operation that transfers 32 vector elements from a
+memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would be
+specified as shown below.
+
+Example:
+
+```mlir
+%size = constant 32 : index
+%tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
+%idx = constant 0 : index
+dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] :
+  memref<40 x 8 x vector<16xf32>, (d0, d1) -> (d0, d1), 0>,
+  memref<2 x 4 x vector<16xf32>, (d0, d1) -> (d0, d1), 2>,
+  memref<1 x i32, (d0) -> (d0), 4>
+```
+
+### 'dma_wait' operation
+
+Syntax:
+
+```
+operation ::= `dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type
+```
+
+Blocks until the completion of a DMA operation associated with the tag element
+specified with a tag memref and its indices. The operands include the tag memref
+followed by its indices and the number of elements associated with the DMA being
+waited on. The indices of the tag memref have the same restrictions as
+load/store indices.
+
+Example:
+
+```mlir
+dma_wait %tag[%idx], %size : memref<1 x i32, (d0) -> (d0), 4>
+```
+
+### 'extract_element' operation
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `extract_element` ssa-use `[` ssa-use-list `]` `:` type
+```
+
+The `extract_element` op reads a tensor or vector and returns one element from
+it specified by an index list.
The output of the 'extract_element' is a new +value with the same type as the elements of the tensor or vector. The arity of +indices matches the rank of the accessed value (i.e., if a tensor is of rank 3, +then 3 indices are required for the extract. The indices should all be of +`index` type. + +Examples: + +```mlir +%3 = extract_element %v[%1, %2] : vector<4x4xi32> +%4 = extract_element %t[%1, %2] : tensor<4x4xi32> +%5 = extract_element %ut[%1, %2] : tensor<*xi32> +``` + +### 'load' operation + +Syntax: + +``` +operation ::= ssa-id `=` `load` ssa-use `[` ssa-use-list `]` `:` memref-type +``` + +The `load` op reads an element from a memref specified by an index list. The +output of load is a new value with the same type as the elements of the memref. +The arity of indices is the rank of the memref (i.e., if the memref loaded from +is of rank 3, then 3 indices are required for the load following the memref +identifier). + +In an `affine.if` or `affine.for` body, the indices of a load are restricted to +SSA values bound to surrounding loop induction variables, +[symbols](../LangRef.md#dimensions-and-symbols), results of a +[`constant` operation](#constant-operation), or the result of an `affine.apply` +operation that can in turn take as arguments all of the aforementioned SSA +values or the recursively result of such an `affine.apply` operation. + +Example: + +```mlir +%1 = affine.apply (d0, d1) -> (3*d0) (%i, %j) +%2 = affine.apply (d0, d1) -> (d1+1) (%i, %j) +%12 = load %A[%1, %2] : memref<8x?xi32, #layout, memspace0> + +// Example of an indirect load (treated as non-affine) +%3 = affine.apply (d0) -> (2*d0 + 1)(%12) +%13 = load %A[%3, %2] : memref<4x?xi32, #layout, memspace0> +``` + +**Context:** The `load` and `store` operations are specifically crafted to fully +resolve a reference to an element of a memref, and (in affine `affine.if` and +`affine.for` operations) the compiler can follow use-def chains (e.g. 
through
+[`affine.apply`](Affine.md#affineapply-operation) operations) to precisely
+analyze references at compile-time using polyhedral techniques. This is possible
+because of the
+[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
+in these contexts.
+
+### 'splat' operation
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `splat` ssa-use `:` ( vector-type | tensor-type )
+```
+
+Broadcast the operand to all elements of the result vector or tensor. The
+operand has to be of either integer or float type. When the result is a tensor,
+it has to be statically shaped.
+
+Example:
+
+```mlir
+  %s = load %A[%i] : memref<128xf32>
+  %v = splat %s : vector<4xf32>
+  %t = splat %s : tensor<8x16xi32>
+```
+
+TODO: This operation is easy to extend to broadcast to dynamically shaped
+tensors in the same way dynamically shaped memrefs are handled.
+```mlir
+// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
+// to the sizes of the two dynamic dimensions.
+%m = "foo"() : () -> (index)
+%n = "bar"() : () -> (index)
+%t = splat %s [%m, %n] : tensor<?x?xf32>
+```
+
+### 'store' operation
+
+Syntax:
+
+```
+operation ::= `store` ssa-use `,` ssa-use `[` ssa-use-list `]` `:` memref-type
+```
+
+Store value to memref location given by indices. The value stored should have
+the same type as the elemental type of the memref. The number of arguments
+provided within brackets needs to match the rank of the memref.
+
+In an affine context, the indices of a store are restricted to SSA values bound
+to surrounding loop induction variables,
+[symbols](Affine.md#restrictions-on-dimensions-and-symbols), results of a
+[`constant` operation](#constant-operation), or the result of an
+[`affine.apply`](Affine.md#affineapply-operation) operation that can in turn
+take as arguments all of the aforementioned SSA values or recursively the result
+of such an `affine.apply` operation.
+ +Example: + +```mlir +store %100, %A[%1, 1023] : memref<4x?xf32, #layout, memspace0> +``` + +**Context:** The `load` and `store` operations are specifically crafted to fully +resolve a reference to an element of a memref, and (in polyhedral `affine.if` +and `affine.for` operations) the compiler can follow use-def chains (e.g. +through [`affine.apply`](Affine.md#affineapply-operation) operations) to +precisely analyze references at compile-time using polyhedral techniques. This +is possible because of the +[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols) +in these contexts. + +### 'tensor_load' operation + +Syntax: + +``` +operation ::= ssa-id `=` `tensor_load` ssa-use-and-type +``` + +Create a tensor from a memref, making an independent copy of the element data. +The result value is a tensor whose shape and element type match the memref +operand. + +Example: + +```mlir +// Produces a value of tensor<4x?xf32> type. +%12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0> +``` + +### 'tensor_store' operation + +Syntax: + +``` +operation ::= `tensor_store` ssa-use `,` ssa-use `:` memref-type +``` + +Stores the contents of a tensor into a memref. The first operand is a value of +tensor type, the second operand is a value of memref type. The shapes and +element types of these must match, and are specified by the memref type. + +Example: + +```mlir +%9 = dim %8, 1 : tensor<4x?xf32> +%10 = alloc(%9) : memref<4x?xf32, #layout, memspace0> +tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0> +``` + +## Unary Operations + +### 'absf' operation + +Syntax: + +``` +operation ::= ssa-id `=` `absf` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar absolute value. +%a = absf %b : f64 + +// SIMD vector element-wise absolute value. +%f = absf %g : vector<4xf32> + +// Tensor element-wise absolute value. +%x = absf %y : tensor<4x?xf8> +``` + +The `absf` operation computes the absolute value. 
It takes one operand and +returns one result of the same type. This type may be a float scalar type, a +vector whose element type is float, or a tensor of floats. It has no standard +attributes. + +### 'ceilf' operation + +Syntax: + +``` +operation ::= ssa-id `=` `ceilf` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar ceiling value. +%a = ceilf %b : f64 + +// SIMD vector element-wise ceiling value. +%f = ceilf %g : vector<4xf32> + +// Tensor element-wise ceiling value. +%x = ceilf %y : tensor<4x?xf8> +``` + +The `ceilf` operation computes the ceiling of a given value. It takes one +operand and returns one result of the same type. This type may be a float +scalar type, a vector whose element type is float, or a tensor of floats. It +has no standard attributes. + +### 'cos' operation + +Syntax: + +``` +operation ::= ssa-id `=` `cos` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar cosine value. +%a = cos %b : f64 + +// SIMD vector element-wise cosine value. +%f = cos %g : vector<4xf32> + +// Tensor element-wise cosine value. +%x = cos %y : tensor<4x?xf8> +``` + +The `cos` operation computes the cosine of a given value. It takes one operand +and returns one result of the same type. This type may be a float scalar type, +a vector whose element type is float, or a tensor of floats. It has no standard +attributes. + +### 'exp' operation + +Syntax: + +``` +operation ::= ssa-id `=` `exp` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar natural exponential. +%a = exp %b : f64 + +// SIMD vector element-wise natural exponential. +%f = exp %g : vector<4xf32> + +// Tensor element-wise natural exponential. +%x = exp %y : tensor<4x?xf8> +``` + +The `exp` operation takes one operand and returns one result of the same type. +This type may be a float scalar type, a vector whose element type is float, or a +tensor of floats. It has no standard attributes. 
+ +### 'negf' operation + +Syntax: + +``` +operation ::= ssa-id `=` `negf` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar negation value. +%a = negf %b : f64 + +// SIMD vector element-wise negation value. +%f = negf %g : vector<4xf32> + +// Tensor element-wise negation value. +%x = negf %y : tensor<4x?xf8> +``` + +The `negf` operation computes the negation of a given value. It takes one +operand and returns one result of the same type. This type may be a float +scalar type, a vector whose element type is float, or a tensor of floats. It +has no standard attributes. + +### 'tanh' operation + +Syntax: + +``` +operation ::= ssa-id `=` `tanh` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar hyperbolic tangent value. +%a = tanh %b : f64 + +// SIMD vector element-wise hyperbolic tangent value. +%f = tanh %g : vector<4xf32> + +// Tensor element-wise hyperbolic tangent value. +%x = tanh %y : tensor<4x?xf8> +``` + +The `tanh` operation computes the hyperbolic tangent. It takes one operand and +returns one result of the same type. This type may be a float scalar type, a +vector whose element type is float, or a tensor of floats. It has no standard +attributes. + +## Arithmetic Operations + +Basic arithmetic in MLIR is specified by standard operations described in this +section. + +### 'addi' operation + +Syntax: + +``` +operation ::= ssa-id `=` `addi` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar addition. +%a = addi %b, %c : i64 + +// SIMD vector element-wise addition, e.g. for Intel SSE. +%f = addi %g, %h : vector<4xi32> + +// Tensor element-wise addition. +%x = addi %y, %z : tensor<4x?xi8> +``` + +The `addi` operation takes two operands and returns one result, each of these is +required to be the same type. This type may be an integer scalar type, a vector +whose element type is integer, or a tensor of integers. It has no standard +attributes. 
+ +### 'addf' operation + +Syntax: + +``` +operation ::= ssa-id `=` `addf` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar addition. +%a = addf %b, %c : f64 + +// SIMD vector addition, e.g. for Intel SSE. +%f = addf %g, %h : vector<4xf32> + +// Tensor addition. +%x = addf %y, %z : tensor<4x?xbf16> +``` + +The `addf` operation takes two operands and returns one result, each of these is +required to be the same type. This type may be a floating point scalar type, a +vector whose element type is a floating point type, or a floating point tensor. + +It has no standard attributes. + +TODO: In the distant future, this will accept optional attributes for fast math, +contraction, rounding mode, and other controls. + +### 'and' operation + +Bitwise integer and. + +Syntax: + +``` +operation ::= ssa-id `=` `and` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar integer bitwise and. +%a = and %b, %c : i64 + +// SIMD vector element-wise bitwise integer and. +%f = and %g, %h : vector<4xi32> + +// Tensor element-wise bitwise integer and. +%x = and %y, %z : tensor<4x?xi8> +``` + +The `and` operation takes two operands and returns one result, each of these is +required to be the same type. This type may be an integer scalar type, a vector +whose element type is integer, or a tensor of integers. It has no standard +attributes. + +### 'cmpi' operation + +Syntax: + +``` +operation ::= ssa-id `=` `cmpi` string-literal `,` ssa-id `,` ssa-id `:` type +``` + +Examples: + +```mlir +// Custom form of scalar "signed less than" comparison. +%x = cmpi "slt", %lhs, %rhs : i32 + +// Generic form of the same operation. +%x = "std.cmpi"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1 + +// Custom form of vector equality comparison. +%x = cmpi "eq", %lhs, %rhs : vector<4xi64> + +// Generic form of the same operation. 
+%x = "std.cmpi"(%lhs, %rhs) {predicate = 0 : i64} + : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> +``` + +The `cmpi` operation is a generic comparison for integer-like types. Its two +arguments can be integers, vectors or tensors thereof as long as their types +match. The operation produces an i1 for the former case, a vector or a tensor of +i1 with the same shape as inputs in the other cases. + +Its first argument is an attribute that defines which type of comparison is +performed. The following comparisons are supported: + +- equal (mnemonic: `"eq"`; integer value: `0`) +- not equal (mnemonic: `"ne"`; integer value: `1`) +- signed less than (mnemonic: `"slt"`; integer value: `2`) +- signed less than or equal (mnemonic: `"sle"`; integer value: `3`) +- signed greater than (mnemonic: `"sgt"`; integer value: `4`) +- signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) +- unsigned less than (mnemonic: `"ult"`; integer value: `6`) +- unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) +- unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) +- unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) + +The result is `1` if the comparison is true and `0` otherwise. For vector or +tensor operands, the comparison is performed elementwise and the element of the +result indicates whether the comparison is true for the operand elements with +the same indices as those of the result. + +Note: while the custom assembly form uses strings, the actual underlying +attribute has integer type (or rather enum class in C++ code) as seen from the +generic assembly form. String literals are used to improve readability of the IR +by humans. + +This operation only applies to integer-like operands, but not floats. 
The main +reason being that comparison operations have diverging sets of attributes: +integers require sign specification while floats require various floating +point-related particularities, e.g., `-ffast-math` behavior, IEEE754 compliance, +etc +([rationale](../Rationale.md#splitting-floating-point-vs-integer-operations)). +The type of comparison is specified as attribute to avoid introducing ten +similar operations, taking into account that they are often implemented using +the same operation downstream +([rationale](../Rationale.md#specifying-comparison-kind-as-attribute)). The +separation between signed and unsigned order comparisons is necessary because of +integers being signless. The comparison operation must know how to interpret +values with the foremost bit being set: negatives in two's complement or large +positives +([rationale](../Rationale.md#specifying-sign-in-integer-comparison-operations)). + +### 'constant' operation + +Syntax: + +``` +operation ::= ssa-id `=` `constant` attribute-value `:` type +``` + +The `constant` operation produces an SSA value equal to some constant specified +by an attribute. This is the way that MLIR uses to form simple integer and +floating point constants, as well as more exotic things like references to +functions and (TODO!) tensor/vector constants. + +The `constant` operation is represented with a single attribute named "value". +The type specifies the result type of the operation. + +Examples: + +```mlir +// Integer constant +%1 = constant 42 : i32 + +// Reference to function @myfn. 
+%3 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
+
+// Equivalent generic forms
+%1 = "std.constant"() {value = 42 : i32} : () -> i32
+%3 = "std.constant"() {value = @myfn}
+    : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
+
+```
+
+MLIR does not allow direct references to functions in SSA operands because the
+compiler is multithreaded, and disallowing SSA values to directly reference a
+function simplifies this
+([rationale](../Rationale.md#multithreading-the-compiler)).
+
+### 'copysign' operation
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `copysign` ssa-use `,` ssa-use `:` type
+```
+
+Examples:
+
+```mlir
+// Scalar copysign value.
+%a = copysign %b, %c : f64
+
+// SIMD vector element-wise copysign value.
+%f = copysign %g, %h : vector<4xf32>
+
+// Tensor element-wise copysign value.
+%x = copysign %y, %z : tensor<4x?xf8>
+```
+
+The `copysign` returns a value with the magnitude of the first operand and the
+sign of the second operand. It takes two operands and returns one result of the
+same type. This type may be a float scalar type, a vector whose element type is
+float, or a tensor of floats. It has no standard attributes.
+
+### 'divis' operation
+
+Signed integer division. Rounds towards zero. Treats the leading bit as sign,
+i.e. `6 / -2 = -3`.
+
+Note: the semantics of division by zero or signed division overflow (minimum
+value divided by -1) is TBD; do NOT assume any specific behavior.
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `divis` ssa-use `,` ssa-use `:` type
+```
+
+Examples:
+
+```mlir
+// Scalar signed integer division.
+%a = divis %b, %c : i64
+
+// SIMD vector element-wise division.
+%f = divis %g, %h : vector<4xi32>
+
+// Tensor element-wise integer division.
+%x = divis %y, %z : tensor<4x?xi8>
+```
+
+The `divis` operation takes two operands and returns one result, each of these
+is required to be the same type. This type may be an integer scalar type, a
+vector whose element type is integer, or a tensor of integers.
It has no
+standard attributes.
+
+### 'diviu' operation
+
+Unsigned integer division. Rounds towards zero. Treats the leading bit as the
+most significant, i.e. for `i16` given two's complement representation, `6 /
+-2 = 6 / (2^16 - 2) = 0`.
+
+Note: the semantics of division by zero is TBD; do NOT assume any specific
+behavior.
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `diviu` ssa-use `,` ssa-use `:` type
+```
+
+Examples:
+
+```mlir
+// Scalar unsigned integer division.
+%a = diviu %b, %c : i64
+
+// SIMD vector element-wise division.
+%f = diviu %g, %h : vector<4xi32>
+
+// Tensor element-wise integer division.
+%x = diviu %y, %z : tensor<4x?xi8>
+```
+
+The `diviu` operation takes two operands and returns one result, each of these
+is required to be the same type. This type may be an integer scalar type, a
+vector whose element type is integer, or a tensor of integers. It has no
+standard attributes.
+
+### 'memref_cast' operation
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `memref_cast` ssa-use `:` type `to` type
+```
+
+Examples:
+
+```mlir
+// Discard static dimension information.
+%3 = memref_cast %2 : memref<4x?xf32> to memref<?x?xf32>
+
+// Convert to a type with more known dimensions.
+%4 = memref_cast %3 : memref<?x?xf32> to memref<4x?xf32>
+
+// Convert to a type with unknown rank.
+%5 = memref_cast %3 : memref<?x?xf32> to memref<*xf32>
+
+// Convert to a type with static rank.
+%6 = memref_cast %5 : memref<*xf32> to memref<?x?xf32>
+```
+
+Convert a memref from one type to an equivalent type without changing any data
+elements. The types are equivalent if 1. they both have the same static rank,
+same element type, same mappings, same address space. The operation is invalid
+if converting to a mismatching constant dimension, or 2. exactly one of the
+operands has an unknown rank, and they both have the same element type and same
+address space. The operation is invalid if both operands are of dynamic rank or
+if converting to a mismatching static rank.
+ +### 'mulf' operation + +Syntax: + +``` +operation ::= ssa-id `=` `mulf` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar multiplication. +%a = mulf %b, %c : f64 + +// SIMD pointwise vector multiplication, e.g. for Intel SSE. +%f = mulf %g, %h : vector<4xf32> + +// Tensor pointwise multiplication. +%x = mulf %y, %z : tensor<4x?xbf16> +``` + +The `mulf` operation takes two operands and returns one result, each of these is +required to be the same type. This type may be a floating point scalar type, a +vector whose element type is a floating point type, or a floating point tensor. + +It has no standard attributes. + +TODO: In the distant future, this will accept optional attributes for fast math, +contraction, rounding mode, and other controls. + +### 'or' operation + +Bitwise integer or. + +Syntax: + +``` +operation ::= ssa-id `=` `or` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar integer bitwise or. +%a = or %b, %c : i64 + +// SIMD vector element-wise bitwise integer or. +%f = or %g, %h : vector<4xi32> + +// Tensor element-wise bitwise integer or. +%x = or %y, %z : tensor<4x?xi8> +``` + +The `or` operation takes two operands and returns one result, each of these is +required to be the same type. This type may be an integer scalar type, a vector +whose element type is integer, or a tensor of integers. It has no standard +attributes. + +### 'remis' operation + +Signed integer division remainder. Treats the leading bit as sign, i.e. `6 % +-2 = 0`. + +Note: the semantics of division by zero is TBD; do NOT assume any specific +behavior. + +Syntax: + +``` +operation ::= ssa-id `=` `remis` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar signed integer division remainder. +%a = remis %b, %c : i64 + +// SIMD vector element-wise division remainder. +%f = remis %g, %h : vector<4xi32> + +// Tensor element-wise integer division remainder. 
+%x = remis %y, %z : tensor<4x?xi8> +``` + +The `remis` operation takes two operands and returns one result, each of these +is required to be the same type. This type may be an integer scalar type, a +vector whose element type is integer, or a tensor of integers. It has no +standard attributes. + +### 'remiu' operation + +Unsigned integer division remainder. Treats the leading bit as the most +significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`. + +Note: the semantics of division by zero is TBD; do NOT assume any specific +behavior. + +Syntax: + +``` +operation ::= ssa-id `=` `remiu` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar unsigned integer division remainder. +%a = remiu %b, %c : i64 + +// SIMD vector element-wise division remainder. +%f = remiu %g, %h : vector<4xi32> + +// Tensor element-wise integer division remainder. +%x = remiu %y, %z : tensor<4x?xi8> +``` + +The `remiu` operation takes two operands and returns one result, each of these +is required to be the same type. This type may be an integer scalar type, a +vector whose element type is integer, or a tensor of integers. It has no +standard attributes. + +### 'select' operation + +Syntax: + +``` +operation ::= ssa-id `=` `select` ssa-use `,` ssa-use `,` ssa-use `:` type +``` + +Examples: + +```mlir +// Custom form of scalar selection. +%x = select %cond, %true, %false : i32 + +// Generic form of the same operation. +%x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32 + +// Vector selection is element-wise +%vx = "std.select"(%vcond, %vtrue, %vfalse) + : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32> +``` + +The `select` operation chooses one value based on a binary condition supplied as +its first operand. If the value of the first operand is `1`, the second operand +is chosen, otherwise the third operand is chosen. The second and the third +operand must have the same type. 
+
+The operation applies to vectors and tensors elementwise given the _shape_ of
+all operands is identical. The choice is made for each element individually
+based on the value at the same position as the element in the condition operand.
+
+The `select` operation combined with [`cmpi`](#cmpi-operation) can be used to
+implement `min` and `max` with signed or unsigned comparison semantics.
+
+### 'tensor_cast' operation
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
+```
+
+Examples:
+
+```mlir
+// Convert from unknown rank to rank 2 with unknown dimension sizes.
+%2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
+%2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
+
+// Convert to a type with more known dimensions.
+%3 = "std.tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
+
+// Discard static dimension and rank information.
+%4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
+%5 = "std.tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
+```
+
+Convert a tensor from one type to an equivalent type without changing any data
+elements. The source and destination types must both be tensor types with the
+same element type. If both are ranked, then the rank should be the same and
+static dimensions should match. The operation is invalid if converting to a
+mismatching constant dimension.
+
+### 'xor' operation
+
+Bitwise integer xor.
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `xor` ssa-use `,` ssa-use `:` type
+```
+
+Examples:
+
+```mlir
+// Scalar integer bitwise xor.
+%a = xor %b, %c : i64
+
+// SIMD vector element-wise bitwise integer xor.
+%f = xor %g, %h : vector<4xi32>
+
+// Tensor element-wise bitwise integer xor.
+%x = xor %y, %z : tensor<4x?xi8>
+```
+
+The `xor` operation takes two operands and returns one result, each of these is
+required to be the same type. This type may be an integer scalar type, a vector
+whose element type is integer, or a tensor of integers. It has no standard
+attributes.
diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md new file mode 100644 index 0000000000000000000000000000000000000000..04f5ba71cdbd05c83f60573980335ca82340de3d --- /dev/null +++ b/mlir/docs/Dialects/Vector.md @@ -0,0 +1,14 @@ +# Vector Dialect + +This dialect provides mid-level abstraction for the MLIR super-vectorizer. + +[TOC] + +## Operations + +# To see op documentation + +```sh +mlir-tblgen --gen-op-doc -I /path/to/mlir/include \ +/path/to/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +``` diff --git a/mlir/docs/EDSC.md b/mlir/docs/EDSC.md new file mode 100644 index 0000000000000000000000000000000000000000..eaaeb6c7009bc03ded57904f5ff83cf0f8115ce5 --- /dev/null +++ b/mlir/docs/EDSC.md @@ -0,0 +1,132 @@ +# Background: declarative builders API + +The main purpose of the declarative builders API is to provide an intuitive way +of constructing MLIR programmatically. In the majority of cases, the IR we wish +to construct exhibits structured control-flow. Declarative builders provide an +API to make MLIR construction and manipulation very idiomatic, for the +structured control-flow case, in C++. + +## ScopedContext + +`mlir::edsc::ScopedContext` provides an implicit thread-local context, +supporting a simple declarative API with globally accessible builders. These +declarative builders are available within the lifetime of a `ScopedContext`. + +## ValueHandle and IndexHandle + +`mlir::edsc::ValueHandle` and `mlir::edsc::IndexHandle` provide typed +abstractions around an `mlir::Value`. These abstractions are "delayed", in the +sense that they allow separating declaration from definition. They may capture +IR snippets, as they are built, for programmatic manipulation. Intuitive +operators are provided to allow concise and idiomatic expressions. 
+
+```c++
+ValueHandle zero = constant_index(0);
+IndexHandle i, j, k;
+```
+
+## Intrinsics
+
+`mlir::edsc::ValueBuilder` is a generic wrapper for the `mlir::Builder::create`
+method that operates on `ValueHandle` objects and returns a single ValueHandle.
+For instructions that return no values or that return multiple values, the
+`mlir::edsc::InstructionBuilder` can be used. Named intrinsics are provided as
+syntactic sugar to further reduce boilerplate.
+
+```c++
+using load = ValueBuilder<LoadOp>;
+using store = InstructionBuilder<StoreOp>;
+```
+
+## LoopBuilder and AffineLoopNestBuilder
+
+`mlir::edsc::AffineLoopNestBuilder` provides an interface to allow writing
+concise and structured loop nests.
+
+```c++
+  ScopedContext scope(f.get());
+  ValueHandle i(indexType),
+              j(indexType),
+              lb(f->getArgument(0)),
+              ub(f->getArgument(1));
+  ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type)),
+              f13(constant_float(llvm::APFloat(13.0f), f32Type)),
+              i7(constant_int(7, 32)),
+              i13(constant_int(13, 32));
+  AffineLoopNestBuilder(&i, lb, ub, 3)([&]{
+      lb * index_t(3) + ub;
+      lb + index_t(3);
+      AffineLoopNestBuilder(&j, lb, ub, 2)([&]{
+          ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
+                  index_t(32));
+          ((f7 + f13) / f7) % f13 - f7 * f13;
+          ((i7 + i13) / i7) % i13 - i7 * i13;
+      });
+  });
+```
+
+## IndexedValue
+
+`mlir::edsc::IndexedValue` provides an index notation around load and store
+operations on abstract data types by overloading the C++ assignment and
+parenthesis operators. The relevant loads and stores are emitted as appropriate.
+
+## Putting it all together
+
+With declarative builders, it becomes fairly concise to build rank and
+type-agnostic custom operations even though MLIR does not yet have generic
+types. Here is what a definition of a general pointwise add looks like in
+Tablegen with declarative builders.
+
+```c++
+def AddOp : Op<"x.add">,
+    Arguments<(ins Tensor:$A, Tensor:$B)>,
+    Results<(outs Tensor: $C)> {
+  code referenceImplementation = [{
+    auto ivs = makeIndexHandles(view_A.rank());
+    auto pivs = makePIndexHandles(ivs);
+    IndexedValue A(arg_A), B(arg_B), C(arg_C);
+    AffineLoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())(
+      [&]{
+        C(ivs) = A(ivs) + B(ivs);
+      });
+  }];
+}
+```
+
+Depending on the function signature on which this emitter is called, the
+generated IR resembles the following, for a 4-D memref of `vector<4xi8>`:
+
+```
+// CHECK-LABEL: func @t1(%lhs: memref<3x4x5x6xvector<4xi8>>, %rhs: memref<3x4x5x6xvector<4xi8>>, %result: memref<3x4x5x6xvector<4xi8>>) -> () {
+// CHECK: affine.for {{.*}} = 0 to 3 {
+// CHECK: affine.for {{.*}} = 0 to 4 {
+// CHECK: affine.for {{.*}} = 0 to 5 {
+// CHECK: affine.for {{.*}} = 0 to 6 {
+// CHECK: {{.*}} = load %arg1[{{.*}}] : memref<3x4x5x6xvector<4xi8>>
+// CHECK: {{.*}} = load %arg0[{{.*}}] : memref<3x4x5x6xvector<4xi8>>
+// CHECK: {{.*}} = addi {{.*}} : vector<4xi8>
+// CHECK: store {{.*}}, %arg2[{{.*}}] : memref<3x4x5x6xvector<4xi8>>
+```
+
+or the following, for a 0-D `memref<f32>`:
+
+```
+// CHECK-LABEL: func @t3(%lhs: memref<f32>, %rhs: memref<f32>, %result: memref<f32>) -> () {
+// CHECK: {{.*}} = load %arg1[] : memref<f32>
+// CHECK: {{.*}} = load %arg0[] : memref<f32>
+// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+// CHECK: store {{.*}}, %arg2[] : memref<f32>
+```
+
+Similar APIs are provided to emit the lower-level `loop.for` op with
+`LoopNestBuilder`. See the `builder-api-test.cpp` test for more usage examples.
+
+Since the implementation of declarative builders is in C++, it is also available
+to program the IR with an embedded-DSL flavor directly integrated in MLIR. We
+make use of these properties in the tutorial.
+ +Spoiler: MLIR also provides Python bindings for these builders, and a +full-fledged Python machine learning DSL with automatic differentiation +targeting MLIR was built as an early research collaboration. + diff --git a/mlir/docs/GenericDAGRewriter.md b/mlir/docs/GenericDAGRewriter.md new file mode 100644 index 0000000000000000000000000000000000000000..8cc09f7d17ffdf5a1d186d0ade830421f2daad46 --- /dev/null +++ b/mlir/docs/GenericDAGRewriter.md @@ -0,0 +1,415 @@ +# MLIR Generic DAG Rewriter Infrastructure + +## Introduction and Motivation + +The goal of a compiler IR is to represent code - at various levels of +abstraction which pose different sets of tradeoffs in terms of representational +capabilities and ease of transformation. However, the ability to represent code +is not itself very useful - you also need to be able to implement those +transformations. + +There are many different sorts of compiler transformations, but this document +focuses on a particularly important class of transformation that comes up +repeatedly at scale, and is important for the immediate goals of MLIR: that of +pattern matching on a set of operations and replacing with another set. This is +the key algorithm required to implement the "op fission" algorithm used by the +tf2xla bridge, pattern matching rewrites from TF ops to TF/Lite, peephole +optimizations like "eliminate identity nodes" or "replace x+0 with x", as well +as a useful abstraction to implement optimization algorithms for MLIR graphs at +all levels. 
+
+A particular strength of MLIR (and a major difference vs other compiler
+infrastructures like LLVM, GCC, XLA, TensorFlow, etc) is that it uses a single
+compiler IR to represent code at multiple levels of abstraction: an MLIR
+operation can be a "TensorFlow operation", an "XLA HLO", a "TF Lite
+FlatBufferModel op", a TPU LLO instruction, an LLVM IR instruction (transitively
+including X86, Lanai, CUDA, and other target specific instructions), or anything
+else that the MLIR type system can reasonably express. Because MLIR spans such a
+wide range of different problems, a single infrastructure for performing
+graph-to-graph rewrites can help solve many diverse domain challenges, ranging
+from the TensorFlow graph level down to the machine code level.
+
+[Static single assignment](https://en.wikipedia.org/wiki/Static_single_assignment_form)
+(SSA) representations like MLIR make it easy to access the operands and "users"
+of an operation. As such, a natural abstraction for these graph-to-graph
+rewrites is that of DAG pattern matching: clients define DAG tile patterns, and
+each pattern includes a result DAG to produce and the cost of the result (or,
+inversely, the benefit of doing the replacement). A common infrastructure
+efficiently finds and performs the rewrites.
+
+While this concept is simple, the details are more nuanced. This proposal
+defines and explores a set of abstractions that we feel can solve a wide range
+of different problems, and can be applied to many different sorts of problems
+that MLIR is - and is expected to - face over time. We do this by separating the
+pattern definition and matching algorithm from the "driver" of the computation
+loop, and make space for the patterns to be defined declaratively in the future.
+
+## Related Work
+
+There is a huge amount of related work to consider, given that pretty much every
+compiler in existence has to solve this problem many times over.
Here are a few
+graph rewrite systems we have used, along with the pros and cons of this related
+work. One unifying problem with all of these is that these systems are only
+trying to solve one particular and usually narrow problem: our proposal would
+like to solve many of these problems with a single infrastructure. Of these, the
+most similar design to our proposal is the LLVM DAG-to-DAG instruction selection
+algorithm at the end.
+
+### Constant folding
+
+A degenerate but pervasive case of DAG-to-DAG pattern matching is constant
+folding: given an operation whose operands contain constants, it can often be
+folded to a result constant value.
+
+MLIR already has constant folding routines which provide a simpler API than a
+general DAG-to-DAG pattern matcher, and we expect it to remain because the
+simpler contract makes it applicable in some cases that a generic matcher would
+not. For example, a DAG-rewrite can remove arbitrary nodes in the current
+function, which could invalidate iterators. Constant folding as an API does not
+remove any nodes, it just provides a (list of) constant values and allows the
+clients to update their data structures as necessary.
+
+### AST-Level Pattern Matchers
+
+The literature is full of source-to-source translators which transform
+identities in order to improve performance (e.g. transforming `X*0` into `0`).
+One large example that I'm aware of is the GCC `fold` function, which performs
+[many optimizations](https://github.com/gcc-mirror/gcc/blob/master/gcc/fold-const.c)
+on ASTs. Clang has
+[similar routines](http://releases.llvm.org/3.5.0/tools/clang/docs/InternalsManual.html#constant-folding-in-the-clang-ast)
+for simple constant folding of expressions (as required by the C++ standard) but
+doesn't perform general optimizations on its ASTs.
+
+The primary downside of tree optimizers is that you can't see across operations
+that have multiple uses.
It is +[well known in literature](https://llvm.org/pubs/2008-06-LCTES-ISelUsingSSAGraphs.pdf) +that DAG pattern matching is more powerful than tree pattern matching, but OTOH, +DAG pattern matching can lead to duplication of computation which needs to be +checked for. + +### "Combiners" and other peephole optimizers + +Compilers end up with a lot of peephole optimizers for various things, e.g. the +GCC +["combine" routines](https://github.com/gcc-mirror/gcc/blob/master/gcc/combine.c) +(which try to merge two machine instructions into a single one), the LLVM +[Inst Combine](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/) +[pass](https://llvm.org/docs/Passes.html#instcombine-combine-redundant-instructions), +LLVM's +[DAG Combiner](https://github.com/llvm-mirror/llvm/blob/master/lib/CodeGen/SelectionDAG/DAGCombiner.cpp), +the Swift compiler's +[SIL Combiner](https://github.com/apple/swift/tree/master/lib/SILOptimizer/SILCombiner), +etc. These generally match one or more operations and produce zero or more +operations as a result. The LLVM +[Legalization](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/) +infrastructure has a different outer loop but otherwise works the same way. + +These passes have a lot of diversity, but also have a unifying structure: they +mostly have a worklist outer loop which visits operations. They then use the C++ +visitor pattern (or equivalent) to switch over the class of operation and +dispatch to a method. That method contains a long list of hand-written C++ code +that pattern-matches various special cases. LLVM introduced a "match" function +that allows writing patterns in a somewhat more declarative style using template +metaprogramming (MLIR has similar facilities). 
Here's a simple example:
+
+```c++
+  // Y - (X + 1) --> ~X + Y
+  if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
+    return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);
+```
+
+Here is a somewhat more complicated one (this is not the biggest or most
+complicated :)
+
+```c++
+  // C2 is ODD
+  // LHS = XOR(Y,C1), Y = AND(Z,C2), C1==(C2+1) => LHS == NEG(OR(Z, ~C2))
+  // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
+  if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
+    if (C1->countTrailingZeros() == 0)
+      if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
+        Value *NewOr = Builder.CreateOr(Z, ~(*C2));
+        return Builder.CreateSub(RHS, NewOr, "sub");
+      }
+```
+
+These systems are simple to set up, and pattern matching templates have some
+advantages (they are extensible for new sorts of sub-patterns, look compact at
+point of use). OTOH, they have lots of well known problems, for example:
+
+*   These patterns are very error prone to write, and contain lots of
+    redundancies.
+*   The IR being matched often has identities (e.g. when matching commutative
+    operators) and the C++ code has to handle it manually - take a look at
+    [the full code](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp?view=markup#l775)
+    for `checkForNegativeOperand`, which defines the second pattern.
+*   The matching code compiles slowly, both because it generates tons of code
+    and because the templates instantiate slowly.
+*   Adding new patterns (e.g. for count leading zeros in the example above) is
+    awkward and doesn't often happen.
+*   The cost model for these patterns is not really defined - it is emergent
+    based on the order the patterns are matched in code.
+*   They are non-extensible without rebuilding the compiler.
+*   It isn't practical to apply theorem provers and other tools to these
+    patterns - they cannot be reused for other purposes.
+ +In addition to structured "combiners" like these, there are lots of ad-hoc +systems like the +[LLVM Machine code peephole optimizer](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/PeepholeOptimizer.cpp?view=markup) +which are related. + +### LLVM's DAG-to-DAG Instruction Selection Infrastructure + +The instruction selection subsystem in LLVM is the result of many years worth of +iteration and discovery, driven by the need for LLVM to support code generation +for lots of targets, the complexity of code generators for modern instruction +sets (e.g. X86), and the fanatical pursuit of reusing code across targets. Eli +wrote a +[nice short overview](https://eli.thegreenplace.net/2013/02/25/a-deeper-look-into-the-llvm-code-generator-part-1) +of how this works, and the +[LLVM documentation](https://llvm.org/docs/CodeGenerator.html#select-instructions-from-dag) +describes it in more depth including its advantages and limitations. It allows +writing patterns like this. + +``` +def : Pat<(or GR64:$src, (not (add GR64:$src, 1))), + (BLCI64rr GR64:$src)>; +``` + +This example defines a matcher for the +["blci" instruction](https://en.wikipedia.org/wiki/Bit_Manipulation_Instruction_Sets#TBM_\(Trailing_Bit_Manipulation\)) +in the +[X86 target description](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrInfo.td?view=markup), +there are many others in that file (look for `Pat<>` patterns, since they aren't +entangled in details of the compiler like assembler/disassembler generation +logic). + +For our purposes, there is much to like about this system, for example: + +* It is defined in a declarative format. +* It is extensible to target-defined operations. +* It automates matching across identities, like commutative patterns. +* It allows custom abstractions and intense factoring of target-specific + commonalities. +* It generates compact code - it compiles into a state machine, which is + interpreted. 
+* It allows the instruction patterns to be defined and reused for multiple + purposes. +* The patterns are "type checked" at compile time, detecting lots of bugs + early and eliminating redundancy from the pattern specifications. +* It allows the use of general C++ code for weird/complex cases. + +While there is a lot that is good here, there is also a lot of bad things: + +* All of this machinery is only applicable to instruction selection. Even + directly adjacent problems like the DAGCombiner and Legalizer can't use it. +* This isn't extensible at compiler runtime, you have to rebuild the compiler + to extend it. +* The error messages when failing to match a pattern + [are not exactly optimal](https://www.google.com/search?q=llvm+cannot+select). +* It has lots of implementation problems and limitations (e.g. can't write a + pattern for a multi-result operation) as a result of working with the + awkward SelectionDAG representation and being designed and implemented + lazily. +* This stuff all grew organically over time and has lots of sharp edges. + +### Summary + +MLIR will face a wide range of pattern matching and graph rewrite problems, and +one of the major advantages of having a common representation for code at +multiple levels that it allows us to invest in - and highly leverage - a single +infra for doing this sort of work. + +## Goals + +This proposal includes support for defining pattern matching and rewrite +algorithms on MLIR. We'd like these algorithms to encompass many problems in the +MLIR space, including 1-to-N expansions (e.g. as seen in the TF/XLA bridge when +lowering a "tf.AddN" to multiple "add" HLOs), M-to-1 patterns (as seen in +Grappler optimization passes, e.g. that convert multiple/add into a single +muladd op), as well as general M-to-N patterns (e.g. instruction selection for +target instructions). 
Patterns should have a cost associated with them, and the +common infrastructure should be responsible for sorting out the lowest cost +match for a given application. + +We separate the task of picking a particular locally optimal pattern from a +given root node, the algorithm used to rewrite an entire graph given a +particular set of goals, and the definition of the patterns themselves. We do +this because DAG tile pattern matching is NP complete, which means that there +are no known polynomial time algorithms to optimally solve this problem. +Additionally, we would like to support iterative rewrite algorithms that +progressively transform the input program through multiple steps. Furthermore, +we would like to support many different sorts of clients across the MLIR stack, +and they may have different tolerances for compile time cost, different demands +for optimality, and other algorithmic goals or constraints. + +We aim for MLIR transformations to be easy to implement and reduce the +likelihood for compiler bugs. We expect there to be a very very large number of +patterns that are defined over time, and we believe that these sorts of patterns +will have a very large number of legality/validity constraints - many of which +are difficult to reason about in a consistent way, may be target specific, and +whose implementation may be particularly bug-prone. As such, we aim to design the +API around pattern definition to be simple, resilient to programmer errors, and +allow separation of concerns between the legality of the nodes generated from +the idea of the pattern being defined. + +Finally, error handling is a topmost concern: in addition to allowing patterns +to be defined in a target-independent way that may not apply for all hardware, +we also want failure for any pattern to match to be diagnosable in a reasonable +way. 
To be clear, this is not a solvable problem in general - the space of +malfunction is too great to be fully enumerated and handled optimally, but there +are better and worse ways to handle the situation. MLIR is already designed to +represent the provenance of an operation well. This project aims to propagate +that provenance information precisely, as well as diagnose pattern match +failures with the rationale for why a set of patterns do not apply. + +### Non goals + +This proposal doesn't aim to solve all compiler problems, it is simply a +DAG-to-DAG pattern matching system, starting with a greedy driver algorithm. +Compiler algorithms that require global dataflow analysis (e.g. common +subexpression elimination, conditional constant propagation, and many many +others) will not be directly solved by this infrastructure. + +This proposal is limited to DAG patterns, which (by definition) prevent the +patterns from seeing across cycles in a graph. In an SSA-based IR like MLIR, +this means that these patterns don't see across PHI nodes / basic block +arguments. We consider this acceptable given the set of problems we are trying +to solve - we don't know of any other system that attempts to do so, and +consider the payoff of worrying about this to be low. + +This design includes the ability for DAG patterns to have associated costs +(benefits), but those costs are defined in terms of magic numbers (typically +equal to the number of nodes being replaced). For any given application, the +units of magic numbers will have to be defined. + +## Overall design + +We decompose the problem into four major pieces: + +1. the code that is used to define patterns to match, cost, and their + replacement actions +1. the driver logic to pick the best match for a given root node +1. the client that is implementing some transformation (e.g. a combiner) +1. (future) the subsystem that allows patterns to be described with a + declarative syntax, which sugars step #1. 
+ +We sketch the first three of these pieces, each in turn. This is not intended to +be a concrete API proposal, merely to describe the design. + +### Defining Patterns + +Each pattern will be an instance of a mlir::Pattern class, whose subclasses +implement methods like this. Note that this API is meant for exposition, the +actual details are different for efficiency and coding standards reasons (e.g. +the memory management of `PatternState` is not specified below, etc): + +```c++ +class Pattern { + /// Return the benefit (the inverse of "cost") of matching this pattern. The + /// benefit of a Pattern is always static - rewrites that may have dynamic + /// benefit can be instantiated multiple times (different Pattern instances) + /// for each benefit that they may return, and be guarded by different match + /// condition predicates. + PatternBenefit getBenefit() const { return benefit; } + + /// Return the root node that this pattern matches. Patterns that can + /// match multiple root types are instantiated once per root. + OperationName getRootKind() const { return rootKind; } + + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). On failure, this + /// returns a None value. On success it returns a (possibly null) pattern-specific + /// state wrapped in a Some. This state is passed back into its rewrite + /// function if this match is selected. + virtual Optional<PatternState *> match(Operation *op) const = 0; + + /// Rewrite the IR rooted at the specified operation with the result of + /// this pattern, generating any new operations with the specified + /// rewriter. If an unexpected error is encountered (an internal + /// compiler error), it is emitted through the normal MLIR diagnostic + /// hooks and the IR is left in a valid state. 
+ virtual void rewrite(Operation *op, PatternState *state, + PatternRewriter &rewriter) const; +}; +``` + +In practice, the first patterns we implement will directly subclass and +implement this stuff, but we will define some helpers to reduce boilerplate. +When we have a declarative way to describe patterns, this should be +automatically generated from the description. + +Instances of `Pattern` have a benefit that is static upon construction of the +pattern instance, but may be computed dynamically at pattern initialization +time, e.g. allowing the benefit to be derived from domain specific information, +like the target architecture. This limitation allows MLIR to (eventually) +perform pattern fusion and compile patterns into an efficient state machine, and +[Thier, Ertl, and Krall](https://dl.acm.org/citation.cfm?id=3179501) have shown +that match predicates eliminate the need for dynamically computed costs in +almost all cases: you can simply instantiate the same pattern one time for each +possible cost and use the predicate to guard the match. + +The two-phase nature of this API (match separate from rewrite) is important for +two reasons: 1) some clients may want to explore different ways to tile the +graph, and only rewrite after committing to one tiling. 2) We want to support +runtime extensibility of the pattern sets, but want to be able to statically +compile the bulk of known patterns into a state machine at "compiler compile +time". Both of these reasons lead to us needing to match multiple patterns +before committing to an answer. + +### Picking and performing a replacement + +In the short term, this API can be very simple, something like this can work and +will be useful for many clients: + +```c++ +class PatternMatcher { + // Create a pattern matcher with a bunch of patterns. This constructor + // looks across all of the specified patterns, and builds an internal + // data structure that allows efficient matching. 
+ PatternMatcher(ArrayRef<Pattern *> patterns); + + // Given a specific operation, see if there is some rewrite that is + // interesting. If so, return success and return the list of new + // operations that were created. If not, return failure. + bool matchAndRewrite(Operation *op, + SmallVectorImpl<Operation *> &newlyCreatedOps); +}; +``` + +In practice the interesting part of this class is the acceleration structure it +builds internally. It buckets up the patterns by root operation, and sorts them +by their static benefit. When performing a match, it tests any dynamic patterns, +then tests statically known patterns from highest to lowest benefit. + +### First Client: A Greedy Worklist Combiner + +We expect that there will be lots of clients for this, but a simple greedy +worklist-driven combiner should be powerful enough to serve many important ones, +including the +[TF2XLA op expansion logic](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla/kernels), +many of the pattern substitution passes of the +[TOCO compiler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco) +for TF-Lite, many +[Grappler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/grappler) +passes, and other general performance optimizations for applying identities. + +The structure of this algorithm is straight-forward, here is pseudo code: + +* Walk a function in preorder, adding each operation to a worklist. +* While the worklist is non-empty, pull something off the back (processing + things generally in postorder) + * Perform matchAndRewrite on the operation. If failed, continue to the + next operation. + * On success, add the newly created ops to the worklist and continue. + +## Future directions + +It is important to get implementation and usage experience with this, and many +patterns can be defined using this sort of framework. Over time, we can look to +make it easier to declare patterns in a declarative form (e.g. 
with the LLVM +tblgen tool or something newer/better). Once we have that, we can define an +internal abstraction for describing the patterns to match, allowing better high +level optimization of patterns (including fusion of the matching logic across +patterns, which the LLVM instruction selector does) and allow the patterns to be +defined without rebuilding the compiler itself. diff --git a/mlir/docs/Glossary.md b/mlir/docs/Glossary.md new file mode 100644 index 0000000000000000000000000000000000000000..542d3756ac70a0830724d55cd531099c1e1b84fb --- /dev/null +++ b/mlir/docs/Glossary.md @@ -0,0 +1,174 @@ +# MLIR Glossary + +This glossary contains definitions of MLIR-specific terminology. It is intended +to be a quick reference document. For terms which are well-documented elsewhere, +definitions are kept brief and the header links to the more in-depth +documentation. + + + +#### [Block](LangRef.md#blocks) + +A sequential list of operations without control flow. + +Also called a [basic block](https://en.wikipedia.org/wiki/Basic_block). + +#### Conversion + +The transformation of code represented in one dialect into a semantically +equivalent representation in another dialect (i.e. inter-dialect conversion) or +the same dialect (i.e. intra-dialect conversion). + +In the context of MLIR, conversion is distinct from [translation](#translation). +Conversion refers to a transformation between (or within) dialects, but all +still within MLIR, whereas translation refers to a transformation between MLIR +and an external representation. + +#### [Declarative Rewrite Rule](DeclarativeRewrites.md) (DRR) + +A [rewrite rule](https://en.wikipedia.org/wiki/Graph_rewriting) which can be +defined declaratively (e.g. through specification in a +[TableGen](https://llvm.org/docs/TableGen/) record). At compiler build time, +these rules are expanded into an equivalent `mlir::RewritePattern` subclass. 
+ +#### [Dialect](LangRef.md#dialects) + +A dialect is a grouping of functionality which can be used to extend the MLIR +system. + +A dialect creates a unique `namespace` within which new +[operations](#operation-op), [attributes](LangRef.md#attributes), and +[types](LangRef.md#type-system) are defined. This is the fundamental method by +which to extend MLIR. + +In this way, MLIR is a meta-IR: its extensible framework allows it to be +leveraged in many different ways (e.g. at different levels of the compilation +process). Dialects provide an abstraction for the different uses of MLIR while +recognizing that they are all a part of the meta-IR that is MLIR. + +The tutorial provides an example of +[interfacing with MLIR](Tutorials/Toy/Ch-2.md#interfacing-with-mlir) in this +way. + +(Note that we have intentionally selected the term "dialect" instead of +"language", as the latter would wrongly suggest that these different namespaces +define entirely distinct IRs.) + +#### Export + +To transform code represented in MLIR into a semantically equivalent +representation which is external to MLIR. + +The tool that performs such a transformation is called an exporter. + +See also: [translation](#translation). + +#### [Function](LangRef.md#functions) + +An [operation](#operation-op) with a name containing one [region](#region). + +The region of a function is not allowed to implicitly capture values defined +outside of the function, and all external references must use function arguments +or attributes that establish a symbolic connection. + +#### Import + +To transform code represented in an external representation into a semantically +equivalent representation in MLIR. + +The tool that performs such a transformation is called an importer. + +See also: [translation](#translation). 
+ +#### Legalization + +The process of transforming operations into a semantically equivalent +representation which adheres to the requirements set by the +[conversion target](DialectConversion.md#conversion-target). + +That is, legalization is accomplished if and only if the new representation +contains only operations which are legal, as specified in the conversion target. + +#### Lowering + +The process of transforming a higher-level representation of an operation into a +lower-level, but semantically equivalent, representation. + +In MLIR, this is typically accomplished through +[dialect conversion](DialectConversion.md). This provides a framework by which +to define the requirements of the lower-level representation, called the +[conversion target](DialectConversion.md#conversion-target), by specifying which +operations are legal versus illegal after lowering. + +See also: [legalization](#legalization). + +#### [Module](LangRef.md#module) + +An [operation](#operation-op) which contains a single region containing a single +block that is comprised of operations. + +This provides an organizational structure for MLIR operations, and is the +expected top-level operation in the IR: the textual parser returns a Module. + +#### [Operation](LangRef.md#operations) (op) + +A unit of code in MLIR. Operations are the building blocks for all code and +computations represented by MLIR. They are fully extensible (there is no fixed +list of operations) and have application-specific semantics. + +An operation can have zero or more [regions](#region). Note that this creates a +nested IR structure, as regions consist of blocks, which in turn, consist of a +list of operations. + +In MLIR, there are two main classes related to operations: `Operation` and `Op`. +Operation is the actual opaque instance of the operation, and represents the +general API into an operation instance. 
An `Op` is the base class of a derived +operation, like `ConstantOp`, and acts as a smart pointer wrapper around an +`Operation*`. + +#### [Region](LangRef.md#regions) + +A [CFG](https://en.wikipedia.org/wiki/Control-flow_graph) of MLIR +[blocks](#block). + +#### Round-trip + +The process of converting from a source format to a target format and then back +to the source format. + +This is a good way of gaining confidence that the target format richly models +the source format. This is particularly relevant in the MLIR context, since +MLIR's multi-level nature allows for easily writing target dialects that model a +source format (such as TensorFlow GraphDef or another non-MLIR format) +faithfully and have a simple conversion procedure. Further cleanup/lowering can +be done entirely within the MLIR representation. This separation - making the +[importer](#import) as simple as possible and performing all further +cleanups/lowering in MLIR - has proven to be a useful design pattern. + +#### [Terminator operation](LangRef.md#terminator-operations) + +An [operation](#operation-op) which *must* terminate a [block](#block). +Terminator operations are a special category of operations. + +#### Transitive lowering + +An A->B->C [lowering](#lowering); that is, a lowering in which multiple patterns +may be applied in order to fully transform an illegal operation into a set of +legal ones. + +This provides the flexibility that the [conversion](#conversion) framework may +perform the lowering in multiple stages of applying patterns (which may utilize +intermediate patterns not in the conversion target) in order to fully legalize +an operation. This is accomplished through +[partial conversion](DialectConversion.md#modes-of-conversion). + +#### Translation + +The transformation of code represented in an external (non-MLIR) representation +into a semantically equivalent representation in MLIR (i.e. +[importing](#import)), or the inverse (i.e. [exporting](#export)). 
+ +In the context of MLIR, translation is distinct from [conversion](#conversion). +Translation refers to a transformation between MLIR and an external +representation, whereas conversion refers to a transformation within MLIR +(between or within dialects). diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md new file mode 100644 index 0000000000000000000000000000000000000000..f413cac28bb00227db8158f825b691ca95ebcd9d --- /dev/null +++ b/mlir/docs/Interfaces.md @@ -0,0 +1,200 @@ +# Introduction to MLIR Interfaces + +MLIR is generic and very extensible; it allows for opaquely representing many +different dialects that have their own operations, attributes, types, and so on. +This allows for dialects to be very expressive in their semantics and for MLIR +to capture many different levels of abstraction. The downside to this is that +transformations and analyses must be extremely conservative about the operations +that they encounter, and must special-case the different dialects that they +support. To combat this, MLIR provides the concept of `interfaces`. + +## Motivation + +Interfaces provide a generic way of interacting with the IR. The goal is to be +able to express transformations/analyses in terms of these interfaces without +encoding specific knowledge about the exact operation or dialect involved. This +makes the compiler more extensible by allowing the addition of new dialects and +operations in a decoupled way with respect to the implementation of +transformations/analyses. + +### Dialect Interfaces + +Dialect interfaces are generally useful for transformation passes or analyses +that want to opaquely operate on operations, even *across* dialects. These +interfaces generally involve wide coverage over the entire dialect and are only +used for a handful of transformations/analyses. In these cases, registering the +interface directly on each operation is overly complex and cumbersome. 
The +interface is not core to the operation, just to the specific transformation. An +example of where this type of interface would be used is inlining. Inlining +generally queries high-level information about the operations within a dialect, +like legality and cost modeling, that often is not specific to one operation. + +A dialect interface can be defined by inheriting from the CRTP base class +`DialectInterfaceBase::Base`. This class provides the necessary utilities for +registering an interface with the dialect so that it can be looked up later. +Once the interface has been defined, dialects can override it using +dialect-specific information. The interfaces defined by a dialect are registered +in a similar mechanism to Attributes, Operations, Types, etc. + +```c++ +/// Define an Inlining interface to allow for dialects to opt-in. +class DialectInlinerInterface : + public DialectInterface::Base { +public: + /// Returns true if the given region 'src' can be inlined into the region + /// 'dest' that is attached to an operation registered to the current dialect. + /// 'valueMapping' contains any remapped values from within the 'src' region. + /// This can be used to examine what values will replace entry arguments into + /// the 'src' region, for example. + virtual bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &valueMapping) const { + return false; + } +}; + +/// Override the inliner interface to add support for inlining affine +/// operations. +struct AffineInlinerInterface : public DialectInlinerInterface { + /// Affine structures have specific inlining constraints. + bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &valueMapping) const final { + ... + } +}; + +/// Register the interface with the dialect. +AffineOpsDialect::AffineOpsDialect(MLIRContext *context) ... 
{ + addInterfaces(); +} +``` + +Once registered, these interfaces can be opaquely queried from the dialect by +the transformation/analysis that wants to use them: + +```c++ +Dialect *dialect = ...; +if (auto *interface = dialect->getInterface()) + ... // The dialect provides this interface. +``` + +#### DialectInterfaceCollections + +An additional utility is provided via DialectInterfaceCollection. This CRTP +class allows for collecting all of the dialects that have registered a given +interface within the context. + +```c++ +class InlinerInterface : public + DialectInterfaceCollection { + /// The hooks for this class mirror the hooks for the DialectInlinerInterface, + /// with default implementations that call the hook on the interface for a + /// given dialect. + virtual bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &valueMapping) const { + auto *handler = getInterfaceFor(dest->getContainingOp()); + return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; + } +}; + +MLIRContext *ctx = ...; +InlinerInterface interface(ctx); +if(!interface.isLegalToInline(...)) + ... +``` + +### Operation Interfaces + +Operation interfaces, as the name suggests, are those registered at the +Operation level. These interfaces provide an opaque view into derived operations +by providing a virtual interface that must be implemented. As an example, the +`Linalg` dialect may implement an interface that provides general queries about +some of the dialects library operations. These queries may provide things like: +the number of parallel loops; the number of inputs and outputs; etc. + +Operation interfaces are defined by overriding the CRTP base class +`OpInterface`. This class takes, as a template parameter, a `Traits` class that +defines a `Concept` and a `Model` class. 
These classes provide an implementation +of concept-based polymorphism, where the Concept defines a set of virtual +methods that are overridden by the Model that is templated on the concrete +operation type. It is important to note that these classes should be pure in +that they contain no non-static data members. Operations that wish to override +this interface should add the provided trait `OpInterface<..>::Trait` upon +registration. + +```c++ +struct ExampleOpInterfaceTraits { + /// Define a base concept class that defines the virtual interface that needs + /// to be overridden. + struct Concept { + virtual ~Concept(); + virtual unsigned getNumInputs(Operation *op) = 0; + }; + + /// Define a model class that specializes a concept on a given operation type. + template <typename ConcreteOp> + struct Model : public Concept { + /// Override the method to dispatch on the concrete operation. + unsigned getNumInputs(Operation *op) final { + return llvm::cast<ConcreteOp>(op).getNumInputs(); + } + }; +}; + +class ExampleOpInterface : public OpInterface<ExampleOpInterface, ExampleOpInterfaceTraits> { +public: + /// Use base class constructor to support LLVM-style casts. + using OpInterface::OpInterface; + + /// The interface dispatches to 'getImpl()', an instance of the concept. + unsigned getNumInputs() { + return getImpl()->getNumInputs(getOperation()); + } +}; + +``` + +Once the interface has been defined, it is registered to an operation by adding +the provided trait `ExampleOpInterface::Trait`. Using this interface is just +like using any other derived operation type, i.e. casting: + +```c++ +/// When defining the operation, the interface is registered via the nested +/// 'Trait' class provided by the 'OpInterface<>' base class. +class MyOp : public Op<MyOp, ExampleOpInterface::Trait> { +public: + /// The definition of the interface method on the derived operation. + unsigned getNumInputs() { return ...; } +}; + +/// Later, we can query if a specific operation (like 'MyOp') overrides the given +/// interface. 
+Operation *op = ...; +if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op)) + llvm::errs() << "num inputs = " << example.getNumInputs() << "\n"; +``` + +#### Utilizing the ODS Framework + +Operation interfaces require a bit of boilerplate to connect all of the pieces +together. The ODS (Operation Definition Specification) framework provides +simplified mechanisms for +[defining interfaces](OpDefinitions.md#operation-interfaces). + +As an example, using the ODS framework would allow for defining the example +interface above as: + +```tablegen +def ExampleOpInterface : OpInterface<"ExampleOpInterface"> { + let description = [{ + This is an example interface definition. + }]; + + let methods = [ + InterfaceMethod< + "Get the number of inputs for the current operation.", + "unsigned", "getNumInputs" + >, + ]; +} +``` diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md new file mode 100644 index 0000000000000000000000000000000000000000..da60b8b892e985d89a1670c6ef51fbe1f70a5252 --- /dev/null +++ b/mlir/docs/LangRef.md @@ -0,0 +1,1497 @@ +# MLIR Specification + +MLIR (Multi-Level IR) is a compiler intermediate representation with +similarities to traditional three-address SSA representations (like +[LLVM IR](http://llvm.org/docs/LangRef.html) or +[SIL](https://github.com/apple/swift/blob/master/docs/SIL.rst)), but which +introduces notions from polyhedral loop optimization as first-class concepts. +This hybrid design is optimized to represent, analyze, and transform high level +dataflow graphs as well as target-specific code generated for high performance +data parallel systems. Beyond its representational capabilities, its single +continuous design provides a framework to lower from dataflow graphs to +high-performance target-specific code. + +This document defines and describes the key concepts in MLIR, and is intended to +be a dry reference document - the [rationale documentation](Rationale.md), +[glossary](Glossary.md), and other content are hosted elsewhere. 
+ +MLIR is designed to be used in three different forms: a human-readable textual +form suitable for debugging, an in-memory form suitable for programmatic +transformations and analysis, and a compact serialized form suitable for storage +and transport. The different forms all describe the same semantic content. This +document describes the human-readable textual form. + +[TOC] + +## High-Level Structure + +MLIR is an +[SSA-based](https://en.wikipedia.org/wiki/Static_single_assignment_form) IR, +which means that values are defined before use and have scope defined by their +dominance relations. Operations may produce zero or more results, and each is a +distinct SSA value with its own type defined by the [type system](#type-system). + +The unit of code in MLIR is an [Operation](#operations). Operations allow for +representing many different concepts: allocating buffers, producing views to +transform them, target-independent arithmetic, target-specific operations, and +even arbitrary user-defined high-level operations including the +[Module](#module) and [Function](#functions) operations. Operations may contain +[Regions](#regions) that represent a Control Flow Graph (CFG) of +[Blocks](#blocks), that contain operations and end with a +[terminator operation](#terminator-operations) (like branches). + +Here's an example of an MLIR module: + +```mlir +// Compute A*B using an implementation of multiply kernel and print the +// result using a TensorFlow op. The dimensions of A and B are partially +// known. The shapes are assumed to match. +func @mul(%A: tensor<100x?xf32>, %B: tensor<?x50xf32>) -> (tensor<100x50xf32>) { + // Compute the inner dimension of %A using the dim operation. + %n = dim %A, 1 : tensor<100x?xf32> + + // Allocate addressable "buffers" and copy tensors %A and %B into them. 
+ %A_m = alloc(%n) : memref<100x?xf32> + tensor_store %A to %A_m : memref<100x?xf32> + + %B_m = alloc(%n) : memref<?x50xf32> + tensor_store %B to %B_m : memref<?x50xf32> + + // Call function @multiply passing memrefs as arguments, + // and getting returned the result of the multiplication. + %C_m = call @multiply(%A_m, %B_m) + : (memref<100x?xf32>, memref<?x50xf32>) -> (memref<100x50xf32>) + + dealloc %A_m : memref<100x?xf32> + dealloc %B_m : memref<?x50xf32> + + // Load the buffer data into a higher level "tensor" value. + %C = tensor_load %C_m : memref<100x50xf32> + dealloc %C_m : memref<100x50xf32> + + // Call TensorFlow built-in function to print the result tensor. + "tf.Print"(%C){message: "mul result"} + : (tensor<100x50xf32>) -> (tensor<100x50xf32>) + + return %C : tensor<100x50xf32> +} + +// A function that multiplies two memrefs and returns the result. +func @multiply(%A: memref<100x?xf32>, %B: memref<?x50xf32>) + -> (memref<100x50xf32>) { + // Compute the inner dimension of %A. + %n = dim %A, 1 : memref<100x?xf32> + + // Allocate memory for the multiplication result. + %C = alloc() : memref<100x50xf32> + + // Multiplication loop nest. + affine.for %i = 0 to 100 { + affine.for %j = 0 to 50 { + store 0 to %C[%i, %j] : memref<100x50xf32> + affine.for %k = 0 to %n { + %a_v = load %A[%i, %k] : memref<100x?xf32> + %b_v = load %B[%k, %j] : memref<?x50xf32> + %prod = mulf %a_v, %b_v : f32 + %c_v = load %C[%i, %j] : memref<100x50xf32> + %sum = addf %c_v, %prod : f32 + store %sum, %C[%i, %j] : memref<100x50xf32> + } + } + } + return %C : memref<100x50xf32> +} +``` + +## Notation + +MLIR has a simple and unambiguous grammar, allowing it to reliably round-trip +through a textual form. This is important for development of the compiler - e.g. +for understanding the state of code as it is being transformed and writing test +cases. + +This document describes the grammar using +[Extended Backus-Naur Form (EBNF)](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form). 
+ +This is the EBNF grammar used in this document, presented in yellow boxes. + +``` +alternation ::= expr0 | expr1 | expr2 // Either expr0 or expr1 or expr2. +sequence ::= expr0 expr1 expr2 // Sequence of expr0 expr1 expr2. +repetition0 ::= expr* // 0 or more occurrences. +repetition1 ::= expr+ // 1 or more occurrences. +optionality ::= expr? // 0 or 1 occurrence. +grouping ::= (expr) // Everything inside parens is grouped together. +literal ::= `abcd` // Matches the literal `abcd`. +``` + +Code examples are presented in blue boxes. + +```mlir +// This is an example use of the grammar above: +// This matches things like: ba, bana, boma, banana, banoma, bomana... +example ::= `b` (`an` | `om`)* `a` +``` + +### Common syntax + +The following core grammar productions are used in this document: + +``` +// TODO: Clarify the split between lexing (tokens) and parsing (grammar). +digit ::= [0-9] +hex_digit ::= [0-9a-fA-F] +letter ::= [a-zA-Z] +id-punct ::= [$._-] + +integer-literal ::= decimal-literal | hexadecimal-literal +decimal-literal ::= digit+ +hexadecimal-literal ::= `0x` hex_digit+ +float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)? +string-literal ::= `"` [^"\n\f\v\r]* `"` TODO define escaping rules +``` + +Not listed here, but MLIR does support comments. They use standard BCPL syntax, +starting with a `//` and going until the end of the line. + +### Identifiers and keywords + +Syntax: + +``` +// Identifiers +bare-id ::= (letter|[_]) (letter|digit|[_$.])* +bare-id-list ::= bare-id (`,` bare-id)* +ssa-id ::= `%` suffix-id +suffix-id ::= (digit+ | ((letter|id-punct) (letter|id-punct|digit)*)) + +symbol-ref-id ::= `@` (suffix-id | string-literal) +ssa-id-list ::= ssa-id (`,` ssa-id)* + +// Uses of an SSA value, e.g. in an operand list to an operation. +ssa-use ::= ssa-id +ssa-use-list ::= ssa-use (`,` ssa-use)* +``` + +Identifiers name entities such as SSA values, types and functions, and are +chosen by the writer of MLIR code. 
Identifiers may be descriptive (e.g. +`%batch_size`, `@matmul`), or may be non-descriptive when they are +auto-generated (e.g. `%23`, `@func42`). Identifier names for SSA values may be +used in an MLIR text file but are not persisted as part of the IR - the printer +will give them anonymous names like `%42`. + +MLIR guarantees identifiers never collide with keywords by prefixing identifiers +with a sigil (e.g. `%`, `#`, `@`, `^`, `!`). In certain unambiguous contexts +(e.g. affine expressions), identifiers are not prefixed, for brevity. New +keywords may be added to future versions of MLIR without danger of collision +with existing identifiers. + +The scope of SSA values is defined based on the standard definition of +[dominance](https://en.wikipedia.org/wiki/Dominator_\(graph_theory\)). Argument +identifiers in mapping functions are in scope for the mapping body. Function +identifiers and mapping identifiers are visible across the entire module. + +## Dialects + +Dialects are the mechanism by which to engage with and extend the MLIR +ecosystem. They allow for defining new [operations](#operations), as well as +[attributes](#attributes) and [types](#type-system). Each dialect is given a +unique `namespace` that is prefixed to each defined attribute/operation/type. +For example, the [Affine dialect](Dialects/Affine.md) defines the namespace: +`affine`. + +MLIR allows for multiple dialects, even those outside of the main tree, to +co-exist together within one module. Dialects are produced and consumed by +certain passes. MLIR provides a [framework](DialectConversion.md) to convert +between, and within, different dialects. 
+ +A few of the dialects supported by MLIR: + +* [Affine dialect](Dialects/Affine.md) +* [GPU dialect](Dialects/GPU.md) +* [LLVM dialect](Dialects/LLVM.md) +* [SPIR-V dialect](Dialects/SPIR-V.md) +* [Standard dialect](Dialects/Standard.md) +* [Vector dialect](Dialects/Vector.md) + +### Target specific operations + +Dialects provide a modular way in which targets can expose target-specific +operations directly through to MLIR. As an example, some targets go through +LLVM. LLVM has a rich set of intrinsics for certain target-independent +operations (e.g. addition with overflow check) as well as providing access to +target-specific operations for the targets it supports (e.g. vector permutation +operations). LLVM intrinsics in MLIR are represented via operations that start +with an "llvm." name. + +Example: + +```mlir +// LLVM: %x = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %a, i16 %b) +%x:2 = "llvm.sadd.with.overflow.i16"(%a, %b) : (i16, i16) -> (i16, i1) +``` + +These operations only work when targeting LLVM as a backend (e.g. for CPUs and +GPUs), and are required to align with the LLVM definition of these intrinsics. + +## Operations + +Syntax: + +``` +operation ::= op-result-list? (generic-operation | custom-operation) + trailing-location? +generic-operation ::= string-literal '(' ssa-use-list? ')' attribute-dict? + `:` function-type +custom-operation ::= bare-id custom-operation-format +op-result-list ::= op-result (`,` op-result)* `=` +op-result ::= ssa-id (`:` integer-literal) +successor-list ::= successor (`,` successor)* +successor ::= caret-id (`:` bb-arg-list)? +region-list ::= region (`,` region)* +trailing-location ::= (`loc` `(` location `)`)? +``` + +MLIR introduces a uniform concept called _operations_ to enable describing many +different levels of abstractions and computations. Operations in MLIR are fully +extensible (there is no fixed list of operations) and have application-specific +semantics. 
For example, MLIR supports +[target-independent operations](Dialects/Standard.md#memory-operations), +[affine operations](Dialects/Affine.md), and +[target-specific machine operations](#target-specific-operations). + +The internal representation of an operation is simple: an operation is +identified by a unique string (e.g. `dim`, `tf.Conv2d`, `x86.repmovsb`, +`ppc.eieio`, etc), can return zero or more results, take zero or more SSA +operands, may have zero or more attributes, may have zero or more successors, +and zero or more enclosed [regions](#regions). The generic printing form +includes all these elements literally, with a function type to indicate the +types of the results and operands. + +Example: + +```mlir +// An operation that produces two results. +// The results of %result can be accessed via the `#` syntax. +%result:2 = "foo_div"() : () -> (f32, i32) + +// Pretty form that defines a unique name for each result. +%foo, %bar = "foo_div"() : () -> (f32, i32) + +// Invoke a TensorFlow function called tf.scramble with two inputs +// and an attribute "fruit". +%2 = "tf.scramble"(%result#0, %bar) {fruit: "banana"} : (f32, i32) -> f32 +``` + +In addition to the basic syntax above, dialects may register known operations. +This allows those dialects to support _custom assembly form_ for parsing and +printing operations. In the operation sets listed below, we show both forms. + +### Terminator Operations + +These are a special category of operations that *must* terminate a block, e.g. +[branches](Dialects/Standard.md#terminator-operations). These operations may +also have a list of successors ([blocks](#blocks) and their arguments). + +Example: + +```mlir +// Branch to ^bb1 or ^bb2 depending on the condition %cond. +// Pass value %v to ^bb2, but not to ^bb1. +"cond_br"(%cond)[^bb1, ^bb2(%v : index)] : (i1) -> () +``` + +### Module + +``` +module ::= `module` symbol-ref-id? (`attributes` attribute-dict)? 
region +``` + +An MLIR module represents an opaque top-level container operation. It contains a +single region containing a single block that is comprised of any operations. +Operations within this region must not implicitly capture values defined above +it. Modules have an optional symbol name that can be used to refer to them in +operations. + +### Functions + +An MLIR Function is an operation with a name containing one [region](#regions). +The region of a function is not allowed to implicitly capture values defined +outside of the function, and all external references must use function arguments +or attributes that establish a symbolic connection (e.g. symbols referenced by +name via a string attribute like [SymbolRefAttr](#symbol-reference-attribute)): + +``` +function ::= `func` function-signature function-attributes? function-body? + +function-signature ::= symbol-ref-id `(` argument-list `)` + (`->` function-result-list)? + +argument-list ::= (named-argument (`,` named-argument)*) | /*empty*/ +argument-list ::= (type attribute-dict? (`,` type attribute-dict?)*) | /*empty*/ +named-argument ::= ssa-id `:` type attribute-dict? + +function-result-list ::= function-result-list-parens + | non-function-type +function-result-list-parens ::= `(` `)` + | `(` function-result-list-no-parens `)` +function-result-list-no-parens ::= function-result (`,` function-result)* +function-result ::= type attribute-dict? + +function-attributes ::= `attributes` attribute-dict +function-body ::= region +``` + +An external function declaration (used when referring to a function declared in +some other module) has no body. While the MLIR textual form provides a nice +inline syntax for function arguments, they are internally represented as "block +arguments" to the first block in the region. + +Only dialect attribute names may be specified in the attribute dictionaries for +function arguments, results, or the function itself. + +Examples: + +```mlir +// External function definitions. 
+
func @abort()
+func @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64
+
+// A function that returns its argument twice:
+func @count(%x: i64) -> (i64, i64)
+  attributes {fruit: "banana"} {
+  return %x, %x: i64, i64
+}
+
+// A function with an argument attribute
+func @example_fn_arg(%x: i32 {swift.self = unit})
+
+// A function with a result attribute
+func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})
+
+// A function with an attribute
+func @example_fn_attr() attributes {dialectName.attrName = false}
+```
+
+## Blocks
+
+Syntax:
+
+```
+block           ::= block-label operation+
+block-label     ::= block-id block-arg-list? `:`
+block-id        ::= caret-id
+caret-id        ::= `^` suffix-id
+ssa-id-and-type ::= ssa-id `:` type
+
+// Non-empty list of names and types.
+ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
+
+block-arg-list ::= `(` ssa-id-and-type-list? `)`
+```
+
+A [block](https://en.wikipedia.org/wiki/Basic_block) is a sequential list of
+operations without control flow (calls are not considered control flow for this
+purpose) that are executed from top to bottom. The last operation in a block is
+a [terminator operation](#terminator-operations), which ends the block.
+
+Blocks in MLIR take a list of block arguments, which represent SSA PHI nodes in
+a functional notation. The arguments are defined by the block, and values are
+provided for these block arguments by branches that go to the block.
+
+Here is a simple example function showing branches, returns, and block
+arguments:
+
+```mlir
+func @simple(i64, i1) -> i64 {
+^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  br ^bb3(%a: i64)    // Branch passes %a as the argument
+
+^bb2:
+  %b = addi %a, %a : i64
+  br ^bb3(%b: i64)    // Branch passes %b as the argument
+
+// ^bb3 receives an argument, named %c, from predecessors
+// and passes it on to bb4 twice. 
+^bb3(%c: i64): + br ^bb4(%c, %c : i64, i64) + +^bb4(%d : i64, %e : i64): + %0 = addi %d, %e : i64 + return %0 : i64 +} +``` + +**Context:** The "block argument" representation eliminates a number of special +cases from the IR compared to traditional "PHI nodes are operations" SSA IRs +(like LLVM). For example, the +[parallel copy semantics](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.524.5461&rep=rep1&type=pdf) +of SSA is immediately apparent, and function arguments are no longer a special +case: they become arguments to the entry block +[[more rationale](Rationale.md#block-arguments-vs-phi-nodes)]. + +## Regions + +### Definition + +A region is a CFG of MLIR [Blocks](#blocks). Regions serve to group semantically +connected blocks, where the semantics is not imposed by the IR. Instead, the +containing operation defines the semantics of the regions it contains. Regions +do not have a name or an address, only the blocks contained in a region do. +Regions are meaningless outside of the containing entity and have no type or +attributes. + +The first block in the region cannot be a successor of any other block. The +syntax for the region is as follows: + +``` +region ::= `{` block* `}` +``` + +The function body is an example of a region: it consists of a CFG of blocks and +has additional semantic restrictions that other types of regions may not have +(block terminators must either branch to a different block, or return from a +function where the types of the `return` arguments must match the result types +of the function signature). + +### Control and Value Scoping + +Regions provide nested control isolation: it is impossible to branch to a block +within a region from outside it or to branch from within a region to a block +outside it. Similarly, it provides a natural scoping for value visibility: SSA +values defined in a region don't escape to the enclosing region, if any. 
By +default, a region can reference values defined outside of the region whenever it +would have been legal to use them as operands to the enclosing operation. + +Example: + +```mlir +func @accelerator_compute(i64, i1) -> i64 { +^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a + cond_br %cond, ^bb1, ^bb2 + +^bb1: + // This def for %value does not dominate ^bb2 + %value = "op.convert"(%a) : (i64) -> i64 + br ^bb3(%a: i64) // Branch passes %a as the argument + +^bb2: + "accelerator.launch"() { + ^bb0: + // Region of code nested under "accelerator.launch", it can reference %a but + // not %value. + %new_value = "accelerator.do_something"(%a) : (i64) -> () + } + // %new_value cannot be referenced outside of the region + +^bb3: + ... +} +``` + +This can be further restricted using the custom verifier associated with the +enclosing operation, for example, disallowing references to values defined +outside the region completely. + +### Control Flow + +Regions are Single-Entry-Multiple-Exit (SEME). This means that control can only +flow into the first block of the region, but can flow out of the region at the +end of any of the contained blocks (This behavior is similar to that of a +function body in most programming languages). When exiting a Region, control is +returned to the enclosing operation. + +The enclosing operation determines the way in which control is transmitted into +the entry block of a Region. The successor to a region’s exit points may not +necessarily exist: for example a call to a function that does not return. +Concurrent or asynchronous execution of regions is unspecified. Operations may +define specific rules of execution, e.g. sequential loops or switch cases. + +A Region may also enter another region within the enclosing operation. If an +operation has multiple regions, the semantics of the operation defines into +which regions the control flows and in which order, if any. 
An operation may +transmit control into regions that were specified in other operations, in +particular those that defined the values the given operation uses. Thus such +operations can be treated opaquely in the enclosing control flow graph, +providing a level of control flow isolation similar to that of the call +operation. + +#### Closure + +Regions allow defining an operation that creates a closure, for example by +“boxing” the body of the region into a value they produce. It remains up to the +operation to define its semantics. Note that if an operation triggers +asynchronous execution of the region, it is under the responsibility of the +operation caller to wait for the region to be executed guaranteeing that any +directly used values remain live. + +### Arguments and Results + +The arguments of the first block of a region are treated as arguments of the +region. The source of these arguments is defined by the semantics of the parent +operation. They may correspond to some of the values the operation itself uses. + +Regions produce a (possibly empty) list of values. The operation semantics +defines the relation between the region results and the operation results. + +## Type System + +Each SSA value in MLIR has a type defined by the type system below. There are a +number of primitive types (like integers) and also aggregate types for tensors +and memory buffers. MLIR [standard types](#standard-types) do not include +structures, arrays, or dictionaries. + +MLIR has an open type system (i.e. there is no fixed list of types), and types +may have application-specific semantics. For example, MLIR supports a set of +[dialect types](#dialect-types). + +``` +type ::= type-alias | dialect-type | standard-type + +type-list-no-parens ::= type (`,` type)* +type-list-parens ::= `(` `)` + | `(` type-list-no-parens `)` + +// This is a common way to refer to an SSA value with a specified type. +ssa-use-and-type ::= ssa-use `:` type + +// Non-empty list of names and types. 
+ssa-use-and-type-list ::= ssa-use-and-type (`,` ssa-use-and-type)* +``` + +### Type Aliases + +``` +type-alias-def ::= '!' alias-name '=' 'type' type +type-alias ::= '!' alias-name +``` + +MLIR supports defining named aliases for types. A type alias is an identifier +that can be used in the place of the type that it defines. These aliases *must* +be defined before their uses. Alias names may not contain a '.', since those +names are reserved for [dialect types](#dialect-types). + +Example: + +```mlir +!avx_m128 = type vector<4 x f32> + +// Using the original type. +"foo"(%x) : vector<4 x f32> -> () + +// Using the type alias. +"foo"(%x) : !avx_m128 -> () +``` + +### Dialect Types + +Similarly to operations, dialects may define custom extensions to the type +system. + +``` +dialect-namespace ::= bare-id + +opaque-dialect-item ::= dialect-namespace '<' string-literal '>' + +pretty-dialect-item ::= dialect-namespace '.' pretty-dialect-item-lead-ident + pretty-dialect-item-body? + +pretty-dialect-item-lead-ident ::= '[A-Za-z][A-Za-z0-9._]*' +pretty-dialect-item-body ::= '<' pretty-dialect-item-contents+ '>' +pretty-dialect-item-contents ::= pretty-dialect-item-body + | '(' pretty-dialect-item-contents+ ')' + | '[' pretty-dialect-item-contents+ ']' + | '{' pretty-dialect-item-contents+ '}' + | '[^[<({>\])}\0]+' + +dialect-type ::= '!' opaque-dialect-item +dialect-type ::= '!' pretty-dialect-item +``` + +Dialect types can be specified in a verbose form, e.g. like this: + +```mlir +// LLVM type that wraps around llvm IR types. +!llvm<"i32*"> + +// Tensor flow string type. +!tf.string + +// Complex type +!foo<"something"> + +// Even more complex type +!foo<"something>>"> +``` + +Dialect types that are simple enough can use the pretty format, which is a +lighter weight syntax that is equivalent to the above forms: + +```mlir +// Tensor flow string type. 
+
!tf.string
+
+// Complex type
+!foo.something
+```
+
+Sufficiently complex dialect types are required to use the verbose form for
+generality. For example, the more complex type shown above wouldn't be valid in
+the lighter syntax: `!foo.something>>` because it contains characters
+that are not allowed in the lighter syntax, as well as unbalanced `<>`
+characters.
+
+See [here](DefiningAttributesAndTypes.md) to learn how to define dialect types.
+
+### Standard Types
+
+Standard types are a core set of [dialect types](#dialect-types) that are
+defined in a builtin dialect and thus available to all users of MLIR.
+
+```
+standard-type ::= complex-type
+                | float-type
+                | function-type
+                | index-type
+                | integer-type
+                | memref-type
+                | none-type
+                | tensor-type
+                | tuple-type
+                | vector-type
+```
+
+#### Complex Type
+
+Syntax:
+
+```
+complex-type ::= `complex` `<` type `>`
+```
+
+The value of `complex` type represents a complex number with a parameterized
+element type, which is composed of a real and imaginary value of that element
+type. The element must be a floating point or integer scalar type.
+
+Examples:
+
+```mlir
+complex<f32>
+complex<i32>
+```
+
+#### Floating Point Types
+
+Syntax:
+
+```
+// Floating point.
+float-type ::= `f16` | `bf16` | `f32` | `f64`
+```
+
+MLIR supports float types of certain widths that are widely used as indicated
+above.
+
+#### Function Type
+
+Syntax:
+
+```
+// MLIR functions can return multiple values.
+function-result-type ::= type-list-parens
+                       | non-function-type
+
+function-type ::= type-list-parens `->` function-result-type
+```
+
+MLIR supports first-class functions: for example, the
+[`constant` operation](Dialects/Standard.md#constant-operation) produces the
+address of a function as an SSA value. This SSA value may be passed to and
+returned from functions, merged across control flow boundaries with
+[block arguments](#blocks), and called with the
+[`call_indirect` operation](Dialects/Standard.md#call-indirect-operation). 
+ +Function types are also used to indicate the arguments and results of +[operations](#operations). + +#### Index Type + +Syntax: + +``` +// Target word-sized integer. +index-type ::= `index` +``` + +The `index` type is a signless integer whose size is equal to the natural +machine word of the target ([rationale](Rationale.md#signless-types)) and is +used by the affine constructs in MLIR. Unlike fixed-size integers, it cannot be +used as an element of vector, tensor or memref type +([rationale](Rationale.md#index-type-disallowed-in-vectortensormemref-types)). + +**Rationale:** integers of platform-specific bit widths are practical to express +sizes, dimensionalities and subscripts. + +#### Integer Type + +Syntax: + +``` +// Sized integers like i1, i4, i8, i16, i32. +integer-type ::= `i` [1-9][0-9]* +``` + +MLIR supports arbitrary precision integer types. Integer types are signless, but +have a designated width. + +**Rationale:** low precision integers (like `i2`, `i4` etc) are useful for +low-precision inference chips, and arbitrary precision integers are useful for +hardware synthesis (where a 13 bit multiplier is a lot cheaper/smaller than a 16 +bit one). + +TODO: Need to decide on a representation for quantized integers +([initial thoughts](Rationale.md#quantized-integer-operations)). + +#### Memref Type + +Syntax: + +``` +memref-type ::= ranked-memref-type | unranked-memref-type + +ranked-memref-type ::= `memref` `<` dimension-list-ranked tensor-memref-element-type + (`,` layout-specification)? | + (`,` memory-space)? `>` + +unranked-memref-type ::= `memref` `<*x` tensor-memref-element-type + (`,` memory-space)? `>` + +stride-list ::= `[` (dimension (`,` dimension)*)? 
`]`
+strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
+layout-specification ::= semi-affine-map | strided-layout
+memory-space ::= integer-literal /* | TODO: address-space-id */
+```
+
+A `memref` type is a reference to a region of memory (similar to a buffer
+pointer, but more powerful). The buffer pointed to by a memref can be allocated,
+aliased and deallocated. A memref can be used to read and write data from/to the
+memory region which it references. Memref types use the same shape specifier as
+tensor types. Note that `memref<f32>`, `memref<0 x f32>`, `memref<1 x 0 x f32>`,
+and `memref<0 x 1 x f32>` are all different types.
+
+A `memref` is allowed to have an unknown rank (e.g. `memref<*xf32>`). The
+purpose of unranked memrefs is to allow external library functions to receive
+memref arguments of any rank without versioning the functions based on the rank.
+Other uses of this type are disallowed or will have undefined behavior.
+
+##### Codegen of Unranked Memref
+
+Using unranked memref in codegen besides the case mentioned above is highly
+discouraged. Codegen is concerned with generating loop nests and specialized
+instructions for high performance; unranked memref is concerned with hiding the
+rank and thus, the number of enclosing loops required to iterate over the data.
+However, if there is a need to code-gen unranked memref, one possible path is to
+cast into a static ranked type based on the dynamic rank. Another possible path
+is to emit a single while loop conditioned on a linear index and perform
+delinearization of the linear index to a dynamic array containing the (unranked)
+indices. While this is possible, it is expected to not be a good idea to perform
+this during codegen as the cost of the translations is expected to be
+prohibitive and optimizations at this level are not expected to be worthwhile. 
+
If expressiveness is the main concern, irrespective of performance, passing
+unranked memrefs to an external C++ library and implementing rank-agnostic logic
+there is expected to be significantly simpler.
+
+Unranked memrefs may provide expressiveness gains in the future and help bridge
+the gap with unranked tensors. Unranked memrefs will not be expected to be
+exposed to codegen but one may query the rank of an unranked memref (a special
+op will be needed for this purpose) and perform a switch and cast to a ranked
+memref as a prerequisite to codegen.
+
+Example:
+
+```mlir
+// With static ranks, we need a function for each possible argument type
+%A = alloc() : memref<16x32xf32>
+%B = alloc() : memref<16x32x64xf32>
+call @helper_2D(%A) : (memref<16x32xf32>)->()
+call @helper_3D(%B) : (memref<16x32x64xf32>)->()
+
+// With unknown rank, the functions can be unified under one unranked type
+%A = alloc() : memref<16x32xf32>
+%B = alloc() : memref<16x32x64xf32>
+// Remove rank info
+%A_u = memref_cast %A : memref<16x32xf32> -> memref<*xf32>
+%B_u = memref_cast %B : memref<16x32x64xf32> -> memref<*xf32>
+// call same function with dynamic ranks
+call @helper(%A_u) : (memref<*xf32>)->()
+call @helper(%B_u) : (memref<*xf32>)->()
+```
+
+The core syntax and representation of a layout specification is a
+[semi-affine map](Dialects/Affine.md#semi-affine-maps). Additionally, syntactic
+sugar is supported to make certain layout specifications more intuitive to read.
+For the moment, a `memref` supports parsing a strided form which is converted to
+a semi-affine map automatically.
+
+The memory space of a memref is specified by a target-specific integer index. If
+no memory space is specified, then the default memory space (0) is used. The
+default space is target specific but always at index 0.
+
+TODO: MLIR will eventually have target-dialects which allow symbolic use of
+memory hierarchy names (e.g. L3, L2, L1, ...) 
but we have not spec'd the details
+of that mechanism yet. Until then, this document pretends that it is valid to
+refer to these memories by `bare-id`.
+
+The notionally dynamic value of a memref value includes the address of the
+buffer allocated, as well as the symbols referred to by the shape, layout map,
+and index maps.
+
+Examples of memref static type
+
+```mlir
+// Identity index/layout map
+#identity = (d0, d1) -> (d0, d1)
+
+// Column major layout.
+#col_major = (d0, d1, d2) -> (d2, d1, d0)
+
+// A 2-d tiled layout with tiles of size 128 x 256.
+#tiled_2d_128x256 = (d0, d1) -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)
+
+// A tiled data layout with non-constant tile sizes.
+#tiled_dynamic = (d0, d1)[s0, s1] -> (d0 floordiv s0, d1 floordiv s1,
+                                      d0 mod s0, d1 mod s1)
+
+// A layout that yields a padding on two at either end of the minor dimension.
+#padded = (d0, d1) -> (d0, (d1 + 2) floordiv 2, (d1 + 2) mod 2)
+
+
+// The dimension list "16x32" defines the following 2D index space:
+//
+//   { (i, j) : 0 <= i < 16, 0 <= j < 32 }
+//
+memref<16x32xf32, #identity, memspace0>
+
+// The dimension list "16x4x?" defines the following 3D index space:
+//
+//   { (i, j, k) : 0 <= i < 16, 0 <= j < 4, 0 <= k < N }
+//
+// where N is a symbol which represents the runtime value of the size of
+// the third dimension.
+//
+// %N here binds to the size of the third dimension.
+%A = alloc(%N) : memref<16x4x?xf32, #col_major, memspace0>
+
+// A 2-d dynamic shaped memref that also has a dynamically sized tiled layout.
+// The memref index space is of size %M x %N, while %B1 and %B2 bind to the
+// symbols s0, s1 respectively of the layout map #tiled_dynamic. Data tiles of
+// size %B1 x %B2 in the logical space will be stored contiguously in memory.
+// The allocation size will be (%M ceildiv %B1) * %B1 * (%N ceildiv %B2) * %B2
+// f32 elements.
+%T = alloc(%M, %N) [%B1, %B2] : memref<?x?xf32, #tiled_dynamic>
+
+// A memref that has a two-element padding at either end. 
The allocation size
+// will fit 16 * 68 float elements of data.
+%P = alloc() : memref<16x64xf32, #padded>
+
+// Affine map with symbol 's0' used as offset for the first dimension.
+#imapS = (d0, d1) [s0] -> (d0 + s0, d1)
+// Allocate memref and bind the following symbols:
+// '%n' is bound to the dynamic second dimension of the memref type.
+// '%o' is bound to the symbol 's0' in the affine map of the memref type.
+%n = ...
+%o = ...
+%A = alloc (%n)[%o] : memref<16x?xf32, #imapS>
+```
+
+##### Index Space
+
+A memref dimension list defines an index space within which the memref can be
+indexed to access data.
+
+##### Index
+
+Data is accessed through a memref type using a multidimensional index into the
+multidimensional index space defined by the memref's dimension list.
+
+Examples
+
+```mlir
+// Allocates a memref with 2D index space:
+//   { (i, j) : 0 <= i < 16, 0 <= j < 32 }
+%A = alloc() : memref<16x32xf32, #imapA, memspace0>
+
+// Loads data from memref '%A' using a 2D index: (%i, %j)
+%v = load %A[%i, %j] : memref<16x32xf32, #imapA, memspace0>
+```
+
+##### Index Map
+
+An index map is a one-to-one
+[semi-affine map](Dialects/Affine.md#semi-affine-maps) that transforms a
+multidimensional index from one index space to another. For example, the
+following figure shows an index map which maps a 2-dimensional index from a 2x2
+index space to a 3x3 index space, using symbols `S0` and `S1` as offsets.
+
+![Index Map Example](includes/img/index-map.svg)
+
+The number of domain dimensions and range dimensions of an index map can be
+different, but must match the number of dimensions of the input and output index
+spaces on which the map operates. The index space is always non-negative and
+integral. In addition, an index map must specify the size of each of its range
+dimensions onto which it maps. Index map symbols must be listed in order with
+symbols for dynamic dimension sizes first, followed by other required symbols. 
+
+##### Layout Map
+
+A layout map is a [semi-affine map](Dialects/Affine.md#semi-affine-maps) which
+encodes logical to physical index space mapping, by mapping input dimensions to
+their ordering from most-major (slowest varying) to most-minor (fastest
+varying). Therefore, an identity layout map corresponds to a row-major layout.
+Identity layout maps do not contribute to the MemRef type identification and are
+discarded on construction. That is, a type with an explicit identity map,
+`memref<?x?xf32, (i,j)->(i,j)>`, is strictly the same as the one without layout
+maps, `memref<?x?xf32>`.
+
+Layout map examples:
+
+```mlir
+// MxN matrix stored in row major layout in memory:
+#layout_map_row_major = (i, j) -> (i, j)
+
+// MxN matrix stored in column major layout in memory:
+#layout_map_col_major = (i, j) -> (j, i)
+
+// MxN matrix stored in a 2-d blocked/tiled layout with 64x64 tiles.
+#layout_tiled = (i, j) -> (i floordiv 64, j floordiv 64, i mod 64, j mod 64)
+```
+
+##### Affine Map Composition
+
+A memref specifies a semi-affine map composition as part of its type. A
+semi-affine map composition is a composition of semi-affine maps beginning with
+zero or more index maps, and ending with a layout map. The composition must be
+conformant: the number of dimensions of the range of one map, must match the
+number of dimensions of the domain of the next map in the composition.
+
+The semi-affine map composition specified in the memref type, maps from accesses
+used to index the memref in load/store operations to other index spaces (i.e.
+logical to physical index mapping). Each of the
+[semi-affine maps](Dialects/Affine.md) and thus its composition is required to
+be one-to-one.
+
+The semi-affine map composition can be used in dependence analysis, memory
+access pattern analysis, and for performance optimizations like vectorization,
+copy elision and in-place updates. If an affine map composition is not specified
+for the memref, the identity affine map is assumed. 
+
+##### Strided MemRef
+
+A memref may specify strides as part of its type. A stride specification is a
+list of integer values that are either static or `?` (dynamic case). Strides
+encode the distance, in number of elements, in (linear) memory between
+successive entries along a particular dimension. A stride specification is
+syntactic sugar for an equivalent strided memref representation using
+semi-affine maps. For example, `memref<42x16xf32, offset: 33, strides: [1, 64]>`
+specifies a non-contiguous memory region of `42` by `16` `f32` elements such
+that:
+
+1.  the minimal size of the enclosing memory region must be `33 + 42 * 1 + 16 *
+    64 = 1066` elements;
+2.  the address calculation for accessing element `(i, j)` computes `33 + i +
+    64 * j`
+3.  the distance between two consecutive elements along the outer dimension is
+    `1` element and the distance between two consecutive elements along the
+    inner dimension is `64` elements.
+
+This corresponds to a column major view of the memory region and is internally
+represented as the type `memref<42x16xf32, (i, j) -> (33 + i + 64 * j)>`.
+
+The specification of strides must not alias: given an n-D strided memref,
+indices `(i1, ..., in)` and `(j1, ..., jn)` may not refer to the same memory
+address unless `i1 == j1, ..., in == jn`.
+
+Strided memrefs represent a view abstraction over preallocated data. They are
+constructed with special ops, yet to be introduced. Strided memrefs are a
+special subclass of memrefs with generic semi-affine map and correspond to a
+normalized memref descriptor when lowering to LLVM.
+
+#### None Type
+
+Syntax:
+
+```
+none-type ::= `none`
+```
+
+The `none` type is a unit type, i.e. a type with exactly one possible value,
+where its value does not have a defined dynamic representation. 
+
+#### Tensor Type
+
+Syntax:
+
+```
+tensor-type ::= `tensor` `<` dimension-list tensor-memref-element-type `>`
+tensor-memref-element-type ::= vector-element-type | vector-type | complex-type
+
+// memref requires a known rank, but tensor does not.
+dimension-list ::= dimension-list-ranked | (`*` `x`)
+dimension-list-ranked ::= (dimension `x`)*
+dimension ::= `?` | decimal-literal
+```
+
+SSA values of tensor type represent aggregate N-dimensional data values, and
+have a known element type. It may have an unknown rank (indicated by `*`) or may
+have a fixed rank with a list of dimensions. Each dimension may be a static
+non-negative decimal constant or be dynamically determined (indicated by `?`).
+
+The runtime representation of the MLIR tensor type is intentionally abstracted -
+you cannot control layout or get a pointer to the data. For low level buffer
+access, MLIR has a [`memref` type](#memref-type). This abstracted runtime
+representation holds both the tensor data values as well as information about
+the (potentially dynamic) shape of the tensor. The
+[`dim` operation](Dialects/Standard.md#dim-operation) returns the size of a
+dimension from a value of tensor type.
+
+Note: hexadecimal integer literals are not allowed in tensor type declarations
+to avoid confusion between `0xf32` and `0 x f32`. Zero sizes are allowed in
+tensors and treated as other sizes, e.g., `tensor<0 x 1 x i32>` and `tensor<1 x
+0 x i32>` are different types. Since zero sizes are not allowed in some other
+types, such tensors should be optimized away before lowering tensors to vectors.
+
+Examples:
+
+```mlir
+// Tensor with unknown rank.
+tensor<* x f32>
+
+// Known rank but unknown dimensions.
+tensor<? x ? x ? x ? x f32>
+
+// Partially known dimensions.
+tensor<? x ? x 13 x ? x f32>
+
+// Full static shape.
+tensor<17 x 4 x 13 x 4 x f32>
+
+// Tensor with rank zero. Represents a scalar.
+tensor<f32>
+
+// Zero-element dimensions are allowed. 
+tensor<0 x 42 x f32>

// Zero-element tensor of f32 type (hexadecimal literals not allowed here).
tensor<0xf32>
```

#### Tuple Type

Syntax:

```
tuple-type ::= `tuple` `<` (type ( `,` type)*)? `>`
```

The value of `tuple` type represents a fixed-size collection of elements, where
each element may be of a different type.

**Rationale:** Though this type is first class in the type system, MLIR provides
no standard operations for operating on `tuple` types
([rationale](Rationale.md#tuple-types)).

Examples:

```mlir
// Empty tuple.
tuple<>

// Single element
tuple<f32>

// Many elements.
tuple<i32, f32, tensor<i1>, i5>
```

#### Vector Type

Syntax:

```
vector-type ::= `vector` `<` static-dimension-list vector-element-type `>`
vector-element-type ::= float-type | integer-type

static-dimension-list ::= (decimal-literal `x`)+
```

The vector type represents a SIMD style vector, used by target-specific
operation sets like AVX. While the most common use is for 1D vectors (e.g.
vector<16 x f32>) we also support multidimensional registers on targets that
support them (like TPUs).

Vector shapes must be positive decimal integers.

Note: hexadecimal integer literals are not allowed in vector type declarations,
`vector<0x42xi32>` is invalid because it is interpreted as a 2D vector with
shape `(0, 42)` and zero shapes are not allowed.

## Attributes

Syntax:

```
attribute-dict ::= `{` `}`
                 | `{` attribute-entry (`,` attribute-entry)* `}`
attribute-entry ::= dialect-attribute-entry | dependent-attribute-entry
dialect-attribute-entry ::= dialect-namespace `.` bare-id `=` attribute-value
dependent-attribute-entry ::= dependent-attribute-name `=` attribute-value
dependent-attribute-name ::= (letter|[_]) (letter|digit|[_$])*
```

Attributes are the mechanism for specifying constant data on operations in
places where a variable is never allowed - e.g.
the index of a +[`dim` operation](Dialects/Standard.md#dim-operation), or the stride of a +convolution. They consist of a name and a concrete attribute value. The set of +expected attributes, their structure, and their interpretation are all +contextually dependent on what they are attached to. + +There are two main classes of attributes: dependent and dialect. Dependent +attributes derive their structure and meaning from what they are attached to; +e.g., the meaning of the `index` attribute on a `dim` operation is defined by +the `dim` operation. Dialect attributes, on the other hand, derive their context +and meaning from a specific dialect. An example of a dialect attribute may be a +`swift.self` function argument attribute that indicates an argument is the +self/context parameter. The context of this attribute is defined by the `swift` +dialect and not the function argument. + +Attribute values are represented by the following forms: + +``` +attribute-value ::= attribute-alias | dialect-attribute | standard-attribute +``` + +### Attribute Value Aliases + +``` +attribute-alias ::= '#' alias-name '=' attribute-value +attribute-alias ::= '#' alias-name +``` + +MLIR supports defining named aliases for attribute values. An attribute alias is +an identifier that can be used in the place of the attribute that it defines. +These aliases *must* be defined before their uses. Alias names may not contain a +'.', since those names are reserved for +[dialect attributes](#dialect-attribute-values). + +Example: + +```mlir +#map = (d0) -> (d0 + 10) + +// Using the original attribute. +%b = affine.apply (d0) -> (d0 + 10) (%a) + +// Using the attribute alias. +%b = affine.apply #map(%a) +``` + +### Dialect Attribute Values + +Similarly to operations, dialects may define custom attribute values. 
The
syntactic structure of these values is identical to custom dialect type values,
except that dialect attribute values are distinguished with a leading '#',
while dialect types are distinguished with a leading '!'.

```
dialect-attribute ::= '#' opaque-dialect-item
dialect-attribute ::= '#' pretty-dialect-item
```

Dialect attributes can be specified in a verbose form, e.g. like this:

```mlir
// Complex attribute
#foo<"something<abcd>">

// Even more complex attribute
#foo<"something<a%%123^^^>>>">
```

Dialect attributes that are simple enough can use the pretty format, which is a
lighter weight syntax that is equivalent to the above forms:

```mlir
// Complex attribute
#foo.something<abcd>
```

Sufficiently complex dialect attributes are required to use the verbose form for
generality. For example, the more complex type shown above wouldn't be valid in
the lighter syntax: `#foo.something<a%%123^^^>>>` because it contains characters
that are not allowed in the lighter syntax, as well as unbalanced `<>`
characters.

See [here](DefiningAttributesAndTypes.md) to learn how to define dialect
attribute values.

### Standard Attribute Values

Standard attributes are a core set of
[dialect attributes](#dialect-attribute-values) that are defined in a builtin
dialect and thus available to all users of MLIR.

```
standard-attribute ::= affine-map-attribute
                     | array-attribute
                     | bool-attribute
                     | dictionary-attribute
                     | elements-attribute
                     | float-attribute
                     | integer-attribute
                     | integer-set-attribute
                     | string-attribute
                     | symbol-ref-attribute
                     | type-attribute
                     | unit-attribute
```

#### AffineMap Attribute

Syntax:

```
affine-map-attribute ::= affine-map
```

An affine-map attribute is an attribute that represents an affine-map object.

#### Array Attribute

Syntax:

```
array-attribute ::= `[` (attribute-value (`,` attribute-value)*)?
`]` +``` + +An array attribute is an attribute that represents a collection of attribute +values. + +#### Boolean Attribute + +Syntax: + +``` +bool-attribute ::= bool-literal +``` + +A boolean attribute is a literal attribute that represents a one-bit boolean +value, true or false. + +#### Dictionary Attribute + +Syntax: + +``` +dictionary-attribute ::= `{` (attribute-entry (`,` attribute-entry)*)? `}` +``` + +A dictionary attribute is an attribute that represents a sorted collection of +named attribute values. The elements are sorted by name, and each name must be +unique within the collection. + +#### Elements Attributes + +Syntax: + +``` +elements-attribute ::= dense-elements-attribute + | opaque-elements-attribute + | sparse-elements-attribute +``` + +An elements attribute is a literal attribute that represents a constant +[vector](#vector-type) or [tensor](#tensor-type) value. + +##### Dense Elements Attribute + +Syntax: + +``` +dense-elements-attribute ::= `dense` `<` attribute-value `>` `:` + ( tensor-type | vector-type ) +``` + +A dense elements attribute is an elements attribute where the storage for the +constant vector or tensor value has been packed to the element bitwidth. The +element type of the vector or tensor constant must be of integer, index, or +floating point type. + +##### Opaque Elements Attribute + +Syntax: + +``` +opaque-elements-attribute ::= `opaque` `<` dialect-namespace `,` + hex-string-literal `>` `:` + ( tensor-type | vector-type ) +``` + +An opaque elements attribute is an elements attribute where the content of the +value is opaque. The representation of the constant stored by this elements +attribute is only understood, and thus decodable, by the dialect that created +it. + +Note: The parsed string literal must be in hexadecimal form. 
+ +##### Sparse Elements Attribute + +Syntax: + +``` +sparse-elements-attribute ::= `sparse` `<` attribute-value `,` attribute-value + `>` `:` ( tensor-type | vector-type ) +``` + +A sparse elements attribute is an elements attribute that represents a sparse +vector or tensor object. This is where very few of the elements are non-zero. + +The attribute uses COO (coordinate list) encoding to represent the sparse +elements of the elements attribute. The indices are stored via a 2-D tensor of +64-bit integer elements with shape [N, ndims], which specifies the indices of +the elements in the sparse tensor that contains non-zero values. The element +values are stored via a 1-D tensor with shape [N], that supplies the +corresponding values for the indices. + +Example: + +```mlir + sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32> + +// This represents the following tensor: +/// [[1, 0, 0, 0], +/// [0, 0, 5, 0], +/// [0, 0, 0, 0]] +``` + +#### Float Attribute + +Syntax: + +``` +float-attribute ::= (float-literal (`:` float-type)?) + | (hexadecimal-literal `:` float-type) +``` + +A float attribute is a literal attribute that represents a floating point value +of the specified [float type](#floating-point-types). It can be represented in +the hexadecimal form where the hexadecimal value is interpreted as bits of the +underlying binary representation. This form is useful for representing infinity +and NaN floating point values. To avoid confusion with integer attributes, +hexadecimal literals _must_ be followed by a float type to define a float +attribute. + +Examples: + +``` +42.0 // float attribute defaults to f64 type +42.0 : f32 // float attribute of f32 type +0x7C00 : f16 // positive infinity +0x7CFF : f16 // NaN (one of possible values) +42 : f32 // Error: expected integer type +``` + +#### Integer Attribute + +Syntax: + +``` +integer-attribute ::= integer-literal ( `:` (index-type | integer-type) )? 
+
```

An integer attribute is a literal attribute that represents an integral value of
the specified integer or index type. The default type for this attribute, if one
is not specified, is a 64-bit integer.

#### Integer Set Attribute

Syntax:

```
integer-set-attribute ::= affine-map
```

An integer-set attribute is an attribute that represents an integer-set object.

#### String Attribute

Syntax:

```
string-attribute ::= string-literal (`:` type)?
```

A string attribute is an attribute that represents a string literal value.

#### Symbol Reference Attribute

Syntax:

```
symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)*
```

A symbol reference attribute is a literal attribute that represents a named
reference to an operation that is nested within an operation with the
`OpTrait::SymbolTable` trait. As such, this reference is given meaning by the
nearest parent operation containing the `OpTrait::SymbolTable` trait. It may
optionally contain a set of nested references that further resolve to a symbol
nested within a different symbol table.

This attribute can only be held internally by
[array attributes](#array-attribute) and
[dictionary attributes](#dictionary-attribute) (including the top-level
operation attribute dictionary), i.e. no other attribute kinds such as Locations
or extended attribute kinds. If a reference to a symbol is necessary from
outside of the symbol table that the symbol is defined in, a
[string attribute](#string-attribute) can be used to refer to the symbol name.

**Rationale:** Given that MLIR models global accesses with symbol references, to
enable efficient multi-threading, it becomes difficult to effectively reason
about their uses. By restricting the places that can legally hold a symbol
reference, we can always opaquely reason about a symbol's usage characteristics.
+ +#### Type Attribute + +Syntax: + +``` +type-attribute ::= type +``` + +A type attribute is an attribute that represents a [type object](#type-system). + +#### Unit Attribute + +``` +unit-attribute ::= `unit` +``` + +A unit attribute is an attribute that represents a value of `unit` type. The +`unit` type allows only one value forming a singleton set. This attribute value +is used to represent attributes that only have meaning from their existence. + +One example of such an attribute could be the `swift.self` attribute. This +attribute indicates that a function parameter is the self/context parameter. It +could be represented as a [boolean attribute](#boolean-attribute)(true or +false), but a value of false doesn't really bring any value. The parameter +either is the self/context or it isn't. + +```mlir +// A unit attribute defined with the `unit` value specifier. +func @verbose_form(i1) attributes {dialectName.unitAttr = unit} + +// A unit attribute can also be defined without the value specifier. +func @simple_form(i1) attributes {dialectName.unitAttr} +``` diff --git a/mlir/docs/MLIRForGraphAlgorithms.md b/mlir/docs/MLIRForGraphAlgorithms.md new file mode 100644 index 0000000000000000000000000000000000000000..ac26e5beb9b93829945e8f25a1192a390650fea1 --- /dev/null +++ b/mlir/docs/MLIRForGraphAlgorithms.md @@ -0,0 +1,403 @@ +# MLIR: Incremental Application to Graph Algorithms in ML Frameworks + +The existing documentation about MLIR focuses on long term vision, how its +pieces fit together, and the benefits of modular and composable infrastructure +in the vast and distant future. While this viewpoint appeals to some, it causes +concern for others who are more concerned about the "here and now" - why does it +make sense to make a "revolutionary" change when any individual problem can be +fixed in place? 
+ +This document explains that adoption of MLIR to solve graph based problems +_isn't_ a revolutionary change: it is an incremental series of steps which build +on each other, each of which delivers local value. This document also addresses +some points of confusion that keep coming up. + +One note: even though a major advantage of MLIR is that it can span the full +spectrum from graph algorithms down to low-level code generation, this document +focuses on the use of MLIR for **graph-level algorithms**. MLIR will also unlock +exciting code generation opportunities (particularly given its novel approach to +integrating state of the art polyhedral techniques), but issues that touch on +MLIR's relationship to XLA, Eigen, etc, are out of scope for this particular +doc. + +This document uses TensorFlow as the example given that it is the focus of our +immediate work, but we believe that the same viewpoint could be useful for +people working in the context of other ML frameworks that may consider adopting +MLIR in the future. + +### How is MLIR relevant? + +MLIR is an overloaded acronym which unpacks as "Multi-Level Intermediate +Representation". Its high-level purpose is to provide mechanics for describing +and transforming programs and computations in a flexible way. It provides common +compiler infrastructure for things like constant folding, dead code elimination, +graph rewriting, and others - which are independent of the representational +choices picked by a given dialect (e.g. its concurrency semantics). It was built +with a specific focus on compile time and memory efficiency, accurate +propagation of source location information (important for reporting high quality +errors and warnings) and is designed for testability. + +TensorFlow has numerous subsystems (some of which are proprietary, e.g. +Tensor-RT, nGraph, CoreML, etc) as well as translation layers between these +different subsystems, and these translation layers face similar challenges. 
((As +an aside, the internals of each of these subsystems could often benefit from +MLIR infrastructure, but that isn't a focus of this doc.)) + +A key observation that MLIR makes is that these subsystems often have two things +going on: they are both particular data structures and encodings (e.g. HLO +graphs, TF-Lite's flat buffer format, TensorFlow's Graph format, the ONNX +abstraction, etc) as well as an abstraction of computation (a specific way of +modeling a convolution, a set of supported operations etc). + +MLIR uses a standard IR (i.e., a set of data structures) for representing these +computations - this allows a huge amount of shared infrastructure across these +problem domains. MLIR then allows the definition of domain-specific "dialects" +that describe the set of operations that are legal and supported for a given +application. This means that the actual translations between data structures are +kept as simple as possible - and are thus relatively easy to make "correct". +This allows the common compiler infrastructure to handle the mapping problems +and the other issues within the domain. + +MLIR's design is directly informed by the experience of building (and then +living with) intermediate representations like the LLVM IR, LLVM SelectionDAG, +the LLVM machine instruction representation, Swift SIL IR, and learns new +lessons from TensorFlow and XLA HLO, as well as learning from building countless +research and production systems on top of them. Our goal is to drag the state of +the art in compilers forward, not to merely apply a few well-known techniques to +the machine learning domain. + +### What does adoption mean? + +The point of this document is not to advocate for rewriting any particular +subsystem in TensorFlow - indeed, the burden required to justify a rewrite is +high, and often very specific to that subsystem. 
That said, there are several +subsystems that are about to get rewritten or substantially revised anyway, so +we use those as examples to concretely describe the benefits that MLIR provides +in these cases and what it will take. The subsystems discussed are: + +1. the TF Lite TOCO translator, which we need to improve error + reporting/reliability issues and generalize it to support more ops, and +1. the TF/XLA bridge which needs to improve usability by merging some of its + usage models, support dynamic shapes and generalize guest subsystem support + to Tensor-RT and nGraph. +1. Grappler is another subsystem that is likely to get substantial revisions in + the future, and would definitely benefit from the MLIR framework, but there + are no known plans to do that work at this point, so we don't discuss it + further. + +Adopting MLIR for these works the same way - and, in fact, the work to support +TF Lite is mostly a subset of the larger work to support the functionality of +the TF/XLA bridge. TF Lite and the TF/XLA bridge include several compiler passes +(things like encapsulate, functionalize control flow, lowering of ops, fusion, +constant folding, shape inference, etc). + +MLIR supports converting from TensorFlow Graphs to MLIR and back, which means +that we can start by putting in a no-op translation to MLIR and back into the +pipeline, and verify that nothing breaks. Then we can work on replacing the +compiler transformations one by one by reimplementing them (with the improved +algorithms that we're planning). + +This is a development plan, we wouldn't actually ship a TensorFlow that just +uses MLIR for a single pass. In practice, we'll have the MLIR flag gated under +an option, build out a replacement for an entire subsystem (e.g. the TOCO +translator) and when the time is right, we'll do A/B comparisons and eventually +make a switch and phase out the old code over time. + +## What benefit does MLIR provide? 
+ +The adoption plan above might sound like it only makes things worse in the +immediate term - we have two implementations of the same functionality, we are +dividing our efforts, etc. In order for this to be worth it, we should have a +good sense that we are building towards an improved future that will make +customers and TensorFlow engineers happier when it lands. Here we describe a few +of the benefits that MLIR provides, in no particular order: + +### A Lossless Human Editable Textual Representation + +The MLIR in-memory data structure has a human readable and writable format, as +well as [a specification](LangRef.md) for that format - built just like any +other programming language. Important properties of this format are that it is +compact, easy to read, and lossless. You can dump an MLIR program out to disk +and munge around with it, then send it through a few more passes. + +If you haven't worked with a system that works this way, it is hard to overstate +how big of a deal this in practice: it means that you can call `foo->dump()` on +an IR object to see its full contents, it means you can diff the IR before and +after a change, delta reduce IR files, and many other things. + +### A Graph Verification Pass + +Like many other popular compiler infrastructures, MLIR provides infrastructure +and implementation for a "verifier" which checks that the IR is well formed. The +MLIR verifier is a simple framework that makes it easy to provide a single +source of truth for those correctness properties and is general across all +Dialects (e.g. TF Graph, TF Lite flat buffer, XLA HLO, etc). + +A verifier pass is sort of like a 'super assertion' that catches mistakes in +program transformations early, making you as an engineer more productive, making +the product more reliable, and making it easier to track down bugs when they +appear - because the verifier can be run at any time, either as a compiler pass +or with a single function call. 
+ +While MLIR provides a well-considered infrastructure for IR verification, and +has simple checks for existing TensorFlow operations, there is a lot that should +be added here and lots of opportunity to get involved! + +### Designed for Testability + +There are many aspects of this in MLIR, but we'll focus on compiler +transformations since they are the easiest to understand. Compiler +transformations are modeled as subclasses of the `Pass` C++ class, which are +driven by an `mlir-opt` tool. When combined with a lossless textual +representation, it becomes really easy to write unit tests for compiler +transformations, for example, this is a simple test that shows "x-x" is being +turned into zero: + +```mlir + // RUN: mlir-opt %s -canonicalize | FileCheck %s + func @test_subi_zero_cfg(%arg0: i32) -> i32 { + %y = subi %arg0, %arg0 : i32 + return %y: i32 + } + // CHECK-LABEL: func @test_subi_zero_cfg(%arg0: i32) + // CHECK-NEXT: %c0_i32 = constant 0 : i32 + // CHECK-NEXT: return %c0 +``` + +The "CHECK" comments are interpreted by the +[LLVM FileCheck tool](https://llvm.org/docs/CommandGuide/FileCheck.html), which +is sort of like a really advanced grep. This test is fully self-contained: it +feeds the input into the [canonicalize pass](Canonicalization.md), and checks +that the output matches the CHECK lines. See the `test/Transforms` directory for +more examples. In contrast, standard unit testing exposes the API of the +underlying framework to lots and lots of tests (making it harder to refactor and +move the API), typically requires a lot more code, and exacerbates issues with +link time. For examples, see +[the TEST_F functions in TensorFlow's testsuite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc). 
+

MLIR has been pervasively designed with this sort of design by testability,
allowing us to put in place a culture that expects every behavior changing
commit to include a test case, and for these test cases to be stable and
reliable over time, since they are testing exactly what they are supposed to.
End to end integration tests are still super useful for some things of course!

### Infrastructure for Warnings and Error Diagnostics and Location Tracking

MLIR benefits from the lessons learned from building other compilers - including
Clang which
[set the standard](http://blog.llvm.org/2010/04/amazing-feats-of-clang-error-recovery.html)
for quality of implementation in C/C++ compiler diagnostics. Drawing from this
experience (and fixing mistakes in LLVM), MLIR requires that operations and
functions carry abstract location information, that transformations propagate
this information, and provides standardized mechanisms to emit errors and
warnings, as well as for clients to hook into them to capture and report them in
custom ways.

Why is this important? In practice, many graph-to-graph translators can fail
(e.g. TF Lite when an unsupported op is used) and it is important to be able to
report the error up through to the user in the most precise way possible, in
order for it to be actionable. This includes tracking rewrites through fusions
and fissions of ops, mapping back into language / API specific domains, etc.

More selfishly for infrastructure hackers, this is a huge boon because it means
that it is easy to write good tests for this: the testing tools for MLIR capture
the diagnostics produced by passes (using the standard diagnostic hooks) and
check that they match the expected diagnostics in the testcase.
For example, to +test the dependence analysis infra in the code generator, Andy Davis wrote a +simple pass that checks dependencies and emits them as "notes", allowing him to +write tests like this: + +```mlir + // RUN: mlir-opt %s -memref-dependence-check -verify-diagnostics + func @different_memrefs() { + %m.a = alloc() : memref<100xf32> + %m.b = alloc() : memref<100xf32> + %c0 = constant 0 : index + %c1 = constant 1.0 : f32 + store %c1, %m.a[%c0] : memref<100xf32> + // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + %v0 = load %m.b[%c0] : memref<100xf32> + return + } +``` + +Note that a major limitation of this is that MLIR suffers from a problem of +"garbage in, garbage out": if the input locations to MLIR are imprecise, then +there is nothing that it can do to recover them. There is work underway in +TensorFlow/Python to improve the situation, and Swift for TensorFlow already has +perfect location tracking due to its design. + +### Shape Information Captured in the IR + +In TensorFlow Graphs, each op takes and returns values using a very simple type +system (TF_DataType) in which each value is a tensor of unknown rank and +dimensions. At the same time, many graphs have static shapes easily knowable for +wide swaths of the computation, and even dynamically shaped operations often +have statically knowable dimensions. Many analyses and transformations benefit +and use this information when available, but because TensorFlow graphs don't +capture this (e.g. serialize it to proto), passes have to recompute it on demand +with ShapeRefiner. + +The [MLIR Tensor Type](LangRef.md#tensor-type) directly captures shape +information, so you can have things like: + +```mlir + %x = tf.Add %x, %y : tensor<128 x 8 x ? x f32> +``` + +Capturing this in the IR is expected to speed up transformations (avoiding +recomputing the same info over and over again) which therefore makes it +practical to apply stronger shape analysis algorithms. 
It also makes it easier +to work with the IR, because on-the-side representations can get out of date, +and the API is easier to work with from an ergonomics perspective. + +### Unified Graph Rewriting Infrastructure + +This is still a work in progress, but we have sightlines towards a +[general rewriting infrastructure](GenericDAGRewriter.md) for transforming DAG +tiles into other DAG tiles, using a declarative pattern format. DAG to DAG +rewriting is a generalized solution for many common compiler optimizations, +lowerings, and other rewrites and having an IR enables us to invest in building +a single high-quality implementation. + +Declarative pattern rules are preferable to imperative C++ code for a number of +reasons: they are more compact, easier to reason about, can have checkers +written against them, and new tools can be built that inspect and manipulate the +declarative patterns in interesting ways - e.g. applying theorem provers to +them. It will be exciting to see this ecosystem develop as the infrastructure +matures. + +### Clarified Semantics for TensorFlow Operations + +One of the challenging things about working with TensorFlow is that there are +many invariants and behaviors that need to be preserved and known about when +working with Graphs, and these can be difficult to reason about and lead to +bugs. Things like 'dead values', Switch and Merge nodes, concurrency semantics, +nodes that execute even when passed a dead value, multiple device program +representation - etc... all add complexities that can make it challenging to +reason about whether a transformation or analysis is correct in general. Even +something as simple as constant folding or transforming integer `x-x` into `0` +is non-trivial because you need to consider control dependence edges. + +One of our major goals for the TensorFlow dialect of MLIR is to sort out these +situations and upgrade existing TensorFlow graphs to semantics that are easier +to reason about. 
The solutions to these problems are all still being debated, +but those discussions have already yielded a lot of potential answers: +introducing a `tf_dead_or` types for switch/merge, modeling of TF operations +using futures/async semantics etc. None of these particular battles are critical +or important for MLIR to succeed (because of its "meta" nature, the abstraction +decisions of any given dialect are up for it to decide), but each one that works +out will make it easier to work with and transform TensorFlow operations. We +expect these issues to get nailed down in the next couple of months when MLIR +effort moves beyond TF Lite / TOCO support. The discussions that are happening +now are super valuable and making progress. + +### Ergonomics + +A minor-in-theory, but important-in-practice point is that MLIR is designed to +make it easy, memory efficient, and less error prone to transform code than +other systems. `TensorFlow::Graph` has implementation issues where the same +information is stored redundantly in different places (which must be manually +kept up to date), has somewhat unusual representation of certain constructs +(e.g. the function library, which makes it very difficult to add or remove +functions, e.g. during interprocedural transformations), and stores information +in the graph that is used by the executor, but isn't necessary for program +transformation. + +TensorFlow has made a lot of progress in this area over the years, and there are +lots of ideas about further improvements in the future, we are happy that MLIR +addresses these needs (making it much easier to implement correct program +transformations) today, and are committed to pushing hard to make it better. 
+

### Compile Time Performance and Memory Use

MLIR has been designed to be memory and compile-time efficient in its algorithms
and data structures, using immutable and uniqued structures, low level
bit-packing, and other well-known techniques to avoid unnecessary heap
allocations, and allow simple and safe multithreaded optimization of MLIR
programs. There are other reasons to believe that the MLIR implementations of
common transformations will be more efficient than the Python and C++
TensorFlow::Graph implementations of the same things, given the current
implementation details of TensorFlow.

That said, this is very much a theory at this point. When the new
implementations of various subsystems are available, we will see what happens in
practice: there will be no reason to speculate - we can measure.

## Common Questions and Concerns

Here we address some frequently asked questions and concerns.

### Isn't MLIR a big dependency to take on?

We've heard that at least some people are concerned that MLIR is a "big"
dependency to take on, and could result in large code size. Here are some key
points about MLIR:

1. The entire MLIR codebase is a pretty small C++ code base in absolute terms
   compared to what goes into a modern ML framework.
1. Like LLVM, MLIR is designed as a set of libraries that clients can link in
   or ignore as they wish. For example, the transformations in MLIR are kept
   separate from the core IR abstractions, and dialect specific code (e.g.
   TensorFlow, TF-Lite, XLA, etc) is all independently selectable by the build
   system. Clients that don't care about XLA don't link in that code, whether
   they are a TF-Lite system or a client that is completely unrelated to
   TensorFlow.
1.
MLIR's only third party dependency is on LLVM, but it doesn't depend on LLVM + IR or any other heavy dependency - it just depends on LLVM's support library + which provides efficient hash tables and other + [memory efficient data structures that the STL does not](http://llvm.org/docs/ProgrammersManual.html#picking-the-right-data-structure-for-a-task). + There have been discussions about splitting this set of libraries out to its + own subproject in LLVM that the LLVM IR project depends on. This would be + great for MLIR as well as other LLVM subprojects. +1. TensorFlow and many other frameworks already use LLVM - if so, MLIR would + not be pulling in an additional dependency at all. + +### How does MLIR represent {control flow, concurrency, …} semantics in TensorFlow? + +MLIR provides a dialect that is an isomorphic 1-1 mapping between TensorFlow +graphs and MLIR, as well as a pretty complete translator back and forth (the +only known gap is that a few TF_DataType enums aren't handled yet). MLIR is a +"Multi-Level IR", which allows it to represent code with different abstraction +levels, so the ability to faithfully represent TensorFlow code in a completely +backwards compatible way (even if there are some historical warts!) is critical. + +In *addition* to the isomorphic mapping, we are actively working on efforts to +raise the abstraction level for working with TensorFlow graphs in MLIR. Doing so +would make it even easier to write TensorFlow transformations than it is today, +and would provide a path to migrating TF 1.x graphs forward into the TF 2.x +world. For example, because MLIR has an extensible type system, we can directly +model whether it is impossible for a Tensor value to be a "dead" value - similar +to the use of optional types in modern programming languages. 
+ +These discussions occasionally cause confusion because there are several issues +being mixed up into one: + +* What are the current semantics of TensorFlow graphs, and what invariants can + we rely on? +* What should the semantics be in TensorFlow 2.0? +* What do programs rely on in practice, and if it is unfriendly, can we + migrate it? +* Can we find a way to make it so transforms don't have to worry about the + complexities of Switch/Merge, by using higher level control flow + representations? (tentative answer: yes) +* How should MLIR represent async vs sync operations, what invariants are + provided, how does this dovetail with control flow? +* When is it safe and beneficial to perform optimizations that might reduce + parallelism? + +All of these questions have a "conservative/safe fallback": we can continue +providing exactly the same abstractions that TensorFlow always has. That said, +we are trying hard to level-up the representation (taking advantage of the +"Multi-Level" part of MLIR) because doing so will make it much much easier to +write analyses and transformations than it currently is in TensorFlow. + +### Non Goals + +It is important to point out things that MLIR does not aim to do. For example, +there is no runtime component to MLIR: the TensorFlow executor, the TF Lite +FlatBuffer interpreter, or other existing runtime should be used as-is. + +Another non-goal is that MLIR currently doesn't support a stable binary +encoding. We will certainly add this at some point, but existing formats should +be used for serialization and distribution in the meantime. 
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md new file mode 100644 index 0000000000000000000000000000000000000000..ff3a21fa1bb6ec65cf444b74ddcbf57ea4b21f72 --- /dev/null +++ b/mlir/docs/OpDefinitions.md @@ -0,0 +1,1210 @@ +# Table-driven Operation Definition Specification (ODS) + +In addition to specializing the `mlir::Op` C++ template, MLIR also supports +defining operations in a table-driven manner. This is achieved via +[TableGen][TableGen], which is both a generic language and its tooling to +maintain records of domain-specific information. Facts regarding an operation +are specified concisely into a TableGen record, which will be expanded into an +equivalent `mlir::Op` C++ template specialization at compiler build time. + +This manual explains in detail all the available mechanisms for defining +operations in such a table-driven manner. It aims to be a specification instead +of a tutorial. Please refer to [Quickstart tutorial to adding MLIR graph +rewrite](QuickstartRewrites.md) for the latter. + +In addition to detailing each mechanism, this manual also tries to capture +best practices. They are rendered as quoted bullet points. + +## Motivation + +MLIR allows pluggable dialects, and dialects contain, among others, a list of +operations. This open and extensible ecosystem leads to the "stringly" type IR +problem, e.g., repetitive string comparisons during optimization and analysis +passes, unintuitive accessor methods (e.g., generic/error prone `getOperand(3)` +vs self-documenting `getStride()`) with more generic return types, verbose and +generic constructors without default arguments, verbose textual IR dump, and +so on. Furthermore, operation verification is: + +1. best case: a central string-to-verification-function map, +1. middle case: duplication of verification across the code base, or +1. worst case: no verification functions. + +The fix is to support defining ops in a table-driven manner. 
Then for each +dialect, we can have a central place that contains everything you need to know +about each op, including its constraints, custom assembly form, etc. This +description is also used to generate helper functions and classes to allow +building, verification, parsing, printing, analysis, and many more. + +## Benefits + +Compared to the C++ template, this table-driven approach has several benefits +including but not limited to: + +* **Single source of truth**: We strive to encode all facts regarding an + operation into the record, so that readers don't need to jump among code + snippets to fully understand an operation. +* **Removing boilerplate**: We can automatically generate + operand/attribute/result getter methods, operation build methods, operation + verify methods, and many more utilities from the record. This greatly reduces + the boilerplate needed for defining a new op. +* **Facilitating auto-generation**: The usage of these operation information + records are by no means limited to op definition itself. We can use them to + drive the auto-generation of many other components, like computation graph + serialization. + +## TableGen Syntax + +We use TableGen as the language for specifying operation information. TableGen +itself just provides syntax for writing records; the syntax and constructs +allowed in a TableGen file (typically with filename suffix `.td`) can be found +[here][TableGenIntro]. The formal language specification can be found +[here][TableGenRef]. _Roughly_ speaking, + +* TableGen `class` is similar to C++ class; it can be templated and + subclassed. +* TableGen `def` is similar to C++ object; it can be declared by specializing + a TableGen `class` (e.g., `def MyDef : MyClass<...>;`) or completely + independently (e.g., `def MyDef;`). It cannot be further templated or + subclassed. +* TableGen `dag` is a dedicated type for directed acyclic graph of elements. A + `dag` has one operator and zero or more arguments. 
Its syntax is `(operator
+  arg0, arg1, argN)`. The operator can be any TableGen `def`; an argument can
+  be anything, including `dag` itself. We can have names attached to both the
+  operator and the arguments like `(MyOp:$op_name MyArg:$arg_name)`.
+
+Please see the [language introduction][TableGenIntro] to learn about all the
+types and expressions supported by TableGen.
+
+## Operation Definition
+
+MLIR defines several common constructs to help operation definition and provide
+their semantics via a special [TableGen backend][TableGenBackend]:
+[`OpDefinitionsGen`][OpDefinitionsGen]. These constructs are defined in
+[`OpBase.td`][OpBase]. The main ones are
+
+* The `Op` class: It is the main construct for defining operations. All facts
+  regarding the operation are specified when specializing this class, with the
+  help of the following constructs.
+* The `Dialect` class: Operations belonging to one logical group are placed in
+  the same dialect. The `Dialect` class contains dialect-level information.
+* The `OpTrait` class hierarchy: They are used to specify special properties
+  and constraints of the operation, including whether the operation has side
+  effects or whether its output has the same shape as the input.
+* The `ins`/`outs` marker: These are two special markers built into the
+  `OpDefinitionsGen` backend. They lead the definitions of operands/attributes
+  and results respectively.
+* The `TypeConstraint` class hierarchy: They are used to specify the
+  constraints over operands or results. A notable subclass hierarchy is
+  `Type`, which stands for constraints for common C++ types.
+* The `AttrConstraint` class hierarchy: They are used to specify the
+  constraints over attributes. A notable subclass hierarchy is `Attr`, which
+  stands for constraints for attributes whose values are of common types.
+
+An operation is defined by specializing the `Op` class with concrete contents
+for all the fields it requires. 
For example, `tf.AvgPool` is defined as + +```tablegen +def TF_AvgPoolOp : TF_Op<"AvgPool", [NoSideEffect]> { + let summary = "Performs average pooling on the input."; + + let description = [{ +Each entry in `output` is the mean of the corresponding size `ksize` +window in `value`. + }]; + + let arguments = (ins + TF_FpTensor:$value, + + Confined]>:$ksize, + Confined]>:$strides, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + DefaultValuedAttr:$data_format + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} +``` + +In the following we describe all the fields needed. Please see the definition +of the `Op` class for the complete list of fields supported. + +### Operation name + +The operation name is a unique identifier of the operation within MLIR, e.g., +`tf.Add` for addition operation in the TensorFlow dialect. This is the +equivalent of the mnemonic in assembly language. It is used for parsing and +printing in the textual format. It is also used for pattern matching in graph +rewrites. + +The full operation name is composed of the dialect name and the op name, with +the former provided via the dialect and the latter provided as the second +template parameter to the `Op` class. + +### Operation documentation + +This includes both an one-line `summary` and a longer human-readable +`description`. They will be used to drive automatic generation of dialect +documentation. They need to be provided in the operation's definition body: + +```tablegen +let summary = "..."; + +let description = [{ +... +}]; +``` + +`description` should be written in Markdown syntax. + +Placing the documentation at the beginning is recommended since +it helps in understanding the operation. + +> * Place documentation at the beginning of the operation definition +> * The summary should be short and concise. It should be a one-liner without +> trailing punctuation. Put expanded explanation in description. 
+
+### Operation arguments
+
+There are two kinds of arguments: operands and attributes. Operands are runtime
+values produced by other ops, while attributes are compile-time known constant
+values, including two categories:
+
+1. Natural attributes: these attributes affect the behavior of the operations
+   (e.g., padding for convolution);
+1. Derived attributes: these attributes are not needed to define the operation
+   but are instead derived from information of the operation. E.g., the output
+   shape of type. This is mostly used for convenience interface generation or
+   interaction with other frameworks/translation.
+
+Both operands and attributes are specified inside the `dag`-typed `arguments`,
+led by `ins`:
+
+```tablegen
+let arguments = (ins
+  <type-constraint>:$<operand-name>,
+  ...
+  <attr-constraint>:$<attr-name>,
+  ...
+);
+```
+
+Here `<type-constraint>` is a TableGen `def` from the `TypeConstraint` class
+hierarchy. Similarly, `<attr-constraint>` is a TableGen `def` from the
+`AttrConstraint` class hierarchy. See [Constraints](#constraints) for more
+information.
+
+There are no requirements on the relative order of operands and attributes; they
+can mix freely. The relative order of operands themselves matters. From each
+named argument a named getter will be generated that returns the argument with
+the return type (in the case of attributes the return type will be
+constructed from the storage type, while for operands it will be `Value`). Each
+attribute's raw value (e.g., as stored) can also be accessed via generated
+`<name>Attr` getters for use in transformation passes where the more user
+friendly return type is less suitable.
+
+All the arguments should be named to 1) provide documentation, 2) drive
+auto-generation of getter methods, 3) provide a handle to reference for other
+places like constraints.
+
+#### Variadic operands
+
+To declare a variadic operand, wrap the `TypeConstraint` for the operand with
+`Variadic<...>`.
+
+Normally operations have no variadic operands or just one variadic operand. 
For
+the latter case, it is easy to deduce which dynamic operands are for the static
+variadic operand definition. But if an operation has more than one variadic
+operand, it would be impossible to attribute dynamic operands to the
+corresponding static variadic operand definitions without further information
+from the operation. Therefore, the `SameVariadicOperandSize` trait is needed to
+indicate that all variadic operands have the same number of dynamic values.
+
+#### Optional attributes
+
+To declare an optional attribute, wrap the `AttrConstraint` for the attribute
+with `OptionalAttr<...>`.
+
+#### Attributes with default values
+
+To declare an attribute with a default value, wrap the `AttrConstraint` for the
+attribute with `DefaultValuedAttr<..., "...">`.
+
+The second parameter to `DefaultValuedAttr` should be a string containing the
+C++ default value. For example, a float default value should be specified
+like `"0.5f"`, and an integer array default value should be specified like
+`"{1, 2, 3}"`.
+
+#### Confining attributes
+
+`Confined` is provided as a general mechanism to help modelling further
+constraints on attributes beyond the ones brought by value types. You can use
+`Confined` to compose complex constraints out of more primitive ones. For
+example, a 32-bit integer attribute whose minimum value must be 10 can be
+expressed as `Confined<I32Attr, [IntMinValue<10>]>`. 
+ +Right now, the following primitive constraints are supported: + +* `IntMinValue`: Specifying an integer attribute to be greater than or + equal to `N` +* `IntMaxValue`: Specifying an integer attribute to be less than or equal + to `N` +* `ArrayMinCount`: Specifying an array attribute to have at least `N` + elements +* `IntArrayNthElemEq`: Specifying an integer array attribute's `I`-th + element to be equal to `N` +* `IntArrayNthElemMinValue`: Specifying an integer array attribute's + `I`-th element to be greater than or equal to `N` + +TODO: Design and implement more primitive constraints + +### Operation results + +Similar to operands, results are specified inside the `dag`-typed `results`, led +by `outs`: + +```tablegen +let results = (outs + :$, + ... +); +``` + +#### Variadic results + +Similar to variadic operands, `Variadic<...>` can also be used for results. +And similarly, `SameVariadicResultSize` for multiple variadic results in the +same operation. + +### Operation traits and constraints + +Traits are operation properties that affect syntax or semantics. MLIR C++ +models various traits in the `mlir::OpTrait` namespace. + +Both operation traits, [interfaces](#operation-interfaces), and constraints +involving multiple operands/attributes/results are provided as the second +template parameter to the `Op` class. They should be deriving from the `OpTrait` +class. See [Constraints](#constraints) for more information. + +### Operation interfaces + +[Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by +which to opaquely call methods and access information on an *Op instance*, +without knowing the exact operation type. Operation interfaces defined in C++ +can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside +from using pre-existing interfaces in the C++ API, the ODS framework also +provides a simplified mechanism for defining such interfaces; that removes much +of the boilerplate necessary. 
+ +Providing a definition of the `OpInterface` class will auto-generate the C++ +classes for the interface. An `OpInterface` includes a name, for the C++ class, +a description, and a list of interface methods. + +```tablegen +def MyInterface : OpInterface<"MyInterface"> { + let description = ...; + let methods = [...]; +} +``` + +There are two types of methods that can be used with an interface, +`InterfaceMethod` and `StaticInterfaceMethod`. They are both comprised of the +same core components, with the distinction that `StaticInterfaceMethod` models a +static method on the derived operation. + +An `InterfaceMethod` is comprised of the following components: + +* Description + - A string description of what this method does and its invariants. +* ReturnType + - A string corresponding to the C++ return type of the method. +* MethodName + - A string corresponding to the desired name of the method. +* Arguments (Optional) + - A dag of strings that correspond to a C++ type and variable name + respectively. +* MethodBody (Optional) + - An optional explicit implementation of the interface method. + - `ConcreteOp` is an implicitly defined typename that can be used to refer + to the type of the derived operation currently being operated on. + - In non-static methods, a variable 'ConcreteOp op' is defined and may be + used to refer to an instance of the derived operation. +* DefaultImplementation (Optional) + - An optional explicit default implementation of the interface method. + - This method is placed within the `Trait` class that is attached to the + operation. As such, this method has the same characteristics as any + other [`Trait`](Traits.md) method. + - `ConcreteOp` is an implicitly defined typename that can be used to refer + to the type of the derived operation currently being operated on. + +ODS also allows generating the declarations for the `InterfaceMethod` of the op +if one specifies the interface with `DeclareOpInterfaceMethods` (see example +below). 
+ +Examples: + +```tablegen +def MyInterface : OpInterface<"MyInterface"> { + let description = [{ + My interface is very interesting. ... + }]; + + let methods = [ + // A simple non-static method with no inputs. + InterfaceMethod<"'foo' is a non-static method with no inputs.", + "unsigned", "foo" + >, + + // A new non-static method accepting an input argument. + InterfaceMethod<"/*insert doc here*/", + "Value ", "bar", (ins "unsigned":$i) + >, + + // Query a static property of the derived operation. + StaticInterfaceMethod<"'fooStatic' is a static method with no inputs.", + "unsigned", "fooStatic" + >, + + // Provide the definition of a static interface method. + // Note: `ConcreteOp` corresponds to the derived operation typename. + StaticInterfaceMethod<"/*insert doc here*/", + "Operation *", "create", (ins "OpBuilder &":$builder, "Location":$loc), [{ + return builder.create(loc); + }]>, + + // Provide a definition of the non-static method. + // Note: `op` corresponds to the derived operation variable. + InterfaceMethod<"/*insert doc here*/", + "unsigned", "getNumInputsAndOutputs", (ins), [{ + return op.getNumInputs() + op.getNumOutputs(); + }]>, + + // Provide only a default definition of the method. + // Note: `ConcreteOp` corresponds to the derived operation typename. + InterfaceMethod<"/*insert doc here*/", + "unsigned", "getNumInputsAndOutputs", (ins), /*methodBody=*/[{}], [{ + ConcreteOp op = cast(getOperation()); + return op.getNumInputs() + op.getNumOutputs(); + }]>, + ]; +} + +// Interfaces can optionally be wrapped inside DeclareOpInterfaceMethods. This +// would result in autogenerating declarations for members `foo`, `bar` and +// `fooStatic`. Methods with bodies are not declared inside the op +// declaration but instead handled by the op interface trait directly. +def OpWithInferTypeInterfaceOp : Op<... + [DeclareOpInterfaceMethods]> { ... 
} +``` + +### Builder methods + +For each operation, there are a few builders automatically generated based on +the arguments and returns types. For example, given the following op definition: + +```tablegen +def MyOp : ... { + let arguments = (ins + I32:$i32_operand, + F32:$f32_operand, + ..., + + I32Attr:$i32_attr, + F32Attr:$f32_attr, + ... + ); + + let results = (outs + I32:$i32_result, + F32:$f32_result, + ... + ); +} +``` + +The following builders are generated: + +```c++ +// All result-types/operands/attributes have one aggregate parameter. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + ArrayRef resultTypes, + ValueRange operands, + ArrayRef attributes); + +// Each result-type/operand/attribute has a separate parameter. The parameters +// for attributes are of mlir::Attribute types. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + Type i32_result, Type f32_result, ..., + Value i32_operand, Value f32_operand, ..., + IntegerAttr i32_attr, FloatAttr f32_attr, ...); + +// Each result-type/operand/attribute has a separate parameter. The parameters +// for attributes are raw values unwrapped with mlir::Attribute instances. +// (Note that this builder will not always be generated. See the following +// explanation for more details.) +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + Type i32_result, Type f32_result, ..., + Value i32_operand, Value f32_operand, ..., + APInt i32_attr, StringRef f32_attr, ...); + +// Each operand/attribute has a separate parameter but result type is aggregate. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + ArrayRef resultTypes, + Value i32_operand, Value f32_operand, ..., + IntegerAttr i32_attr, FloatAttr f32_attr, ...); + +// All operands/attributes have aggregate parameters. +// Generated if InferTypeOpInterface interface is specified. 
+static void build(Builder *tblgen_builder, OperationState &tblgen_state, + ValueRange operands, + ArrayRef attributes); + +// (And manually specified builders depending on the specific op.) +``` + +The first form provides basic uniformity so that we can create ops using the +same form regardless of the exact op. This is particularly useful for +implementing declarative pattern rewrites. + +The second and third forms are good for use in manually written code given that +they provide better guarantee via signatures. + +The third form will be generated if any of the op's attribute has different +`Attr.returnType` from `Attr.storageType` and we know how to build an attribute +from an unwrapped value (i.e., `Attr.constBuilderCall` is defined.) +Additionally, for the third form, if an attribute appearing later in the +`arguments` list has a default value, the default value will be supplied in the +declaration. This works for `BoolAttr`, `StrAttr`, `EnumAttr` for now and the +list can grow in the future. So if possible, default valued attribute should be +placed at the end of the `arguments` list to leverage this feature. (This +behavior is essentially due to C++ function parameter default value placement +restrictions.) Otherwise, the builder of the third form will still be generated +but default values for the attributes not at the end of the `arguments` list +will not be supplied in the builder's signature. + +And there may potentially exist other builders depending on the specific op; +please refer to the +[generated C++ file](#run-mlir-tblgen-to-see-the-generated-content) for the +complete list. + +#### Custom builder methods + +However, if the above cases cannot satisfy all needs, you can define additional +convenience build methods with `OpBuilder`. + +`OpBuilder` is a class that takes the parameter list and the optional `build()` +method body. They are separated because we need to generate op declaration and +definition into separate files. 
The parameter list should _include_ `Builder +*builder, OperationState &state`. If the `body` is not provided, _only_ the +builder declaration will be generated; this provides a way to define complicated +builders entirely in C++ files. + +For example, for the following op: + +```tablegen +def MyOp : Op<"my_op", []> { + let arguments = (ins F32Attr:$attr); + + let results = (outs); +} +``` + +If we want to define a builder with a default value for the only attribute, we +can add into `MyOp`: + +```tablegen +def MyOp : ... { + ... + + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, float val = 0.5f", [{ + state.addAttribute("attr", builder->getF32FloatAttr(val)); + }]> + ]; +} +``` + +The generated builder will look like: + +```c++ +static void build(Builder *builder, OperationState &state, float val = 0.5f) { + state.addAttribute("attr", builder->getF32FloatAttr(val)); +} +``` + +### Custom parser and printer methods + +Functions to parse and print the operation's custom assembly form. + +### Custom verifier code + +Verification code will be automatically generated for +[constraints](#constraints) specified on various entities of the op. To +perform _additional_ verification, you can use + +```tablegen +let verifier = [{ + ... +}]; +``` + +Code placed in `verifier` will be called after the auto-generated verification +code. + +### `hasCanonicalizer` + +This boolean field indicate whether canonicalization patterns have been defined +for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should +be defined. + +### `hasFolder` + +This boolean field indicate whether general folding rules have been defined +for this operation. If it is `1`, then `::fold()` should be defined. + +### Extra declarations + +One of the goals of table-driven op definition is to auto-generate as much logic +and methods needed for each op as possible. With that said, there will always be +long-tail cases that won't be covered. 
For such cases, you can use +`extraClassDeclaration`. Code in `extraClassDeclaration` will be copied +literally to the generated C++ op class. + +Note that `extraClassDeclaration` is a mechanism intended for long-tail cases +by power users; for not-yet-implemented widely-applicable cases, improving the +infrastructure is preferable. + +### Generated C++ code + +[OpDefinitionsGen][OpDefinitionsGen] processes the op definition spec file and +generates two files containing the corresponding C++ code: one for declarations, +the other for definitions. The former is generated via the `-gen-op-decls` +command-line option, while the latter is via the `-gen-op-defs` option. + +The definition file contains all the op method definitions, which can be +included and enabled by defining `GET_OP_CLASSES`. For each operation, +OpDefinitionsGen generates an operation class and an +[operand adaptor](#operand-adaptors) class. Besides, it also contains a +comma-separated list of all defined ops, which can be included and enabled by +defining `GET_OP_LIST`. + +#### Class name and namespaces + +For each operation, its generated C++ class name is the symbol `def`ed with +TableGen with dialect prefix removed. The first `_` serves as the delimiter. +For example, for `def TF_AddOp`, the C++ class name would be `AddOp`. +We remove the `TF` prefix because it is for scoping ops; other dialects +may as well define their own `AddOp`s. + +The namespaces of the generated C++ class will come from the dialect's +`cppNamespace` field. For example, if a dialect's `cppNamespace` is `A::B`, +then an op of that dialect will be placed in +`namespace A { namespace B { ... } }`. If a dialect does not specify a +`cppNamespace`, we then use the dialect's name as the namespace. + +This means the qualified name of the generated C++ class does not necessarily +match exactly with the operation name as explained in +[Operation name](#operation-name). 
This is to allow flexible naming to satisfy
+coding style requirements.
+
+#### Operand adaptors
+
+For each operation, we automatically generate an _operand adaptor_. This class
+solves the problem of accessing operands provided as a list of `Value`s without
+using "magic" constants. The operand adaptor takes a reference to an array of
+`Value` and provides methods with the same names as those in the operation class
+to access them. For example, for a binary arithmetic operation, it may provide
+`.lhs()` to access the first operand and `.rhs()` to access the second operand.
+
+The operand adaptor class lives in the same namespace as the operation class,
+and has the name of the operation followed by `OperandAdaptor`. A template
+declaration `OperandAdaptor<>` is provided to look up the operand adaptor for
+the given operation.
+
+Operand adaptors can be used in function templates that also process operations:
+
+```c++
+template <typename BinaryOpTy>
+std::pair<Value, Value> zip(BinaryOpTy &&op) {
+  return std::make_pair(op.lhs(), op.rhs());
+}
+
+void process(AddOp op, ArrayRef<Value> newOperands) {
+  zip(op);
+  zip(OperandAdaptor<AddOp>(newOperands));
+  /*...*/
+}
+```
+
+## Constraints
+
+Constraint is a core concept in table-driven operation definition: operation
+verification and graph operation matching are all based on satisfying
+constraints. So both the operation definition and rewrite rules specification
+significantly involve writing constraints. We have the `Constraint` class in
+[`OpBase.td`][OpBase] as the common base class for all constraints.
+
+An operation's constraint can cover different ranges; it may
+
+* Only concern a single attribute (e.g. being a 32-bit integer greater than 5),
+* Multiple operands and results (e.g., the 1st result's shape must be the same
+  as the 1st operand), or
+* Intrinsic to the operation itself (e.g., having no side effect).
+
+We call them single-entity constraint, multi-entity constraint, and traits,
+respectively. 
+ +### Single-entity constraint + +Constraints scoped to a single operand, attribute, or result are specified at +the entity's declaration place as described in +[Operation arguments](#operation-arguments) and +[Operation results](#operation-results). + +To help modelling constraints of common types, a set of `TypeConstraint`s are +created; they are the `Type` subclass hierarchy. It includes `F32` for the +constraints of being a float, `TensorOf<[F32]>` for the constraints of being +a float tensor, and so on. + +Similarly, a set of `AttrConstraint`s are created for helping modelling +constraints of common attribute kinds. They are the `Attr` subclass hierarchy. +It includes `F32Attr` for the constraints of being a float attribute, +`F32ArrayAttr` for the constraints of being a float array attribute, and so on. + +### Multi-entity constraint + +Constraints involving more than one operand/attribute/result are quite common +on operations, like the element type and shape relation between operands and +results. These constraints should be specified as the `Op` class template +parameter as described in +[Operation traits and constraints](#operation-traits-and-constraints). + +Multi-entity constraints are modeled as `PredOpTrait` (a subclass of `OpTrait`) +in [`OpBase.td`][OpBase].A bunch of constraint primitives are provided to help +specification. See [`OpBase.td`][OpBase] for the complete list. + +### Trait + +Traits are intrinsic properties of the operation like having side effect or not, +commutative or not, whether is a terminator, etc. These constraints should be +specified as the `Op` class template parameter as described in +[Operation traits and constraints](#operation-traits-and-constraints). + +Traits are modeled as `NativeOpTrait` (a subclass of `OpTrait`) in +[`OpBase.td`][OpBase]. They are backed and will be translated into the +corresponding C++ `mlir::OpTrait` classes. 
+ +### How to specify new constraint + +To write a constraint, you need to provide its predicates and give it a +descriptive name. Predicates, modeled with the `Pred` class, are the workhorse +for composing constraints. The predicate for a constraint is typically built up +in a nested manner, using the two categories of predicates: + +1. `CPred`: the primitive leaf predicate. +2. Compound predicate: a predicate composed from child predicates using + predicate combiners (conjunction: `And`, disjunction: `Or`, negation: `Neg`, + substitution: `SubstLeaves`, concatenation: `Concat`). + +`CPred` is the basis for composing more complex predicates. It is the "atom" +predicate from the perspective of TableGen and the "interface" between +TableGen and C++. What is inside is already C++ code, which will be treated +as opaque strings with special placeholders to be substituted. + +You can put any C++ code that returns a boolean value inside a `CPred`, +including evaluating expressions, calling functions, calling class methods, +and so on. + +To help interaction with the C++ environment, there are a few special +placeholders provided to refer to entities in the context where this predicate +is used. They serve as "hooks" to the enclosing environment. This includes +`$_builder`, `$_op`, and `$_self`: + +* `$_builder` will be replaced by a `mlir::Builder` instance so that you can + access common build methods. +* `$_op` will be replaced by the current operation so that you can access + information of the current operation. +* `$_self` will be replaced with the entity this predicate is attached to. + E.g., `BoolAttr` is an attribute constraint that wraps a + `CPred<"$_self.isa()">`. Then for `F32:$attr`,`$_self` will be + replaced by `$attr`. 
For type constraints, it's a little bit special since
+ we want the constraints on each type definition to read naturally and we want
+ to attach type constraints directly to an operand/result, so `$_self` will be
+ replaced by the operand/result's type. E.g., for `F32` in `F32:$operand`, its
+ `$_self` will be expanded as `getOperand(...)->getType()`.
+
+TODO(b/130663252): Reconsider the leading symbol for special placeholders.
+Eventually we want to allow referencing operand/result $-names; such $-names
+can start with underscore.
+
+For example, to write that an attribute `attr` is an `IntegerAttr`, in C++ you
+can just call `attr.isa<IntegerAttr>()`. The code can be wrapped in a `CPred`
+as `$_self.isa<IntegerAttr>()`, with `$_self` as the special placeholder to be
+replaced by the current attribute `attr` at expansion time.
+
+For more complicated predicates, you can wrap them in a single `CPred`, or you
+can use predicate combiners to combine them. For example, to write the
+constraint that an attribute `attr` is a 32-bit or 64-bit integer, you can
+write it as
+
+```tablegen
+And<[
+  CPred<"$_self.isa<IntegerAttr>()">,
+  Or<[
+    CPred<"$_self.cast<IntegerAttr>().getType().isInteger(32)">,
+    CPred<"$_self.cast<IntegerAttr>().getType().isInteger(64)">
+  ]>
+]>
+```
+
+(Note that the above is just to show with a familiar example how you can use
+`CPred` and predicate combiners to write complicated predicates. For integer
+attributes specifically, [`OpBase.td`][OpBase] already defines `I32Attr` and
+`I64Attr`. So you can actually reuse them to write it as `Or<[I32Attr.predicate,
+I64Attr.predicate]>`.)
+
+TODO: Build up a library of reusable primitive constraints
+
+If the predicate is very complex to write with `CPred` together with predicate
+combiners, you can also write it as a normal C++ function and use the `CPred`
+as a way to "invoke" the function. For example, to verify an attribute `attr`
+has some property, you can write a C++ function like
+
+```cpp
+bool HasSomeProperty(Attribute attr) { ... 
} +}
+```
+
+and then define the op as:
+
+```tablegen
+def HasSomeProperty : AttrConstraint<CPred<"HasSomeProperty($_self)">,
+                                     "has some property">;
+
+def MyOp : Op<...> {
+  let arguments = (ins
+    ...
+    HasSomeProperty:$attr
+  );
+}
+```
+
+As to whether we should define the predicate using a single `CPred` wrapping
+the whole expression, multiple `CPred`s with predicate combiners, or a single
+`CPred` "invoking" a function, there are no clear-cut criteria. Defining using
+`CPred` and predicate combiners is preferable since it exposes more information
+(instead of hiding all the logic behind a C++ function) into the op definition
+spec so that it can potentially drive more auto-generation cases. But it will
+require a nice library of common predicates as the building blocks to avoid the
+duplication, which is being worked on right now.
+
+## Attribute Definition
+
+### Enum attributes
+
+Some attributes can only take values from a predefined enum, e.g., the
+comparison kind of a comparison op. To define such attributes, ODS provides
+several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`.
+
+* `StrEnumAttr`: each enum case is a string, the attribute is stored as a
+  [`StringAttr`][StringAttr] in the op.
+* `IntEnumAttr`: each enum case is an integer, the attribute is stored as an
+  [`IntegerAttr`][IntegerAttr] in the op.
+* `BitEnumAttr`: each enum case is a bit, the attribute is stored as an
+  [`IntegerAttr`][IntegerAttr] in the op.
+
+All these `*EnumAttr` attributes require fully specifying all of the allowed
+cases via their corresponding `*EnumAttrCase`. With this, ODS is able to
+generate additional verification to only accept allowed cases. To facilitate the
+interaction between `*EnumAttr`s and their C++ consumers, the
+[`EnumsGen`][EnumsGen] TableGen backend can generate a few common utilities: a
+C++ enum class, `llvm::DenseMapInfo` for the enum class, conversion functions
+from/to strings. 
This is controlled via the `-gen-enum-decls` and +`-gen-enum-defs` command-line options of `mlir-tblgen`. + +For example, given the following `EnumAttr`: + +```tablegen +def Case15: I32EnumAttrCase<"Case15", 15>; +def Case20: I32EnumAttrCase<"Case20", 20>; + +def MyIntEnum: I32EnumAttr<"MyIntEnum", "An example int enum", + [Case15, Case20]> { + let cppNamespace = "Outer::Inner"; + let stringToSymbolFnName = "ConvertToEnum"; + let symbolToStringFnName = "ConvertToString"; +} +``` + +The following will be generated via `mlir-tblgen -gen-enum-decls`: + +```c++ +namespace Outer { +namespace Inner { +// An example int enum +enum class MyIntEnum : uint32_t { + Case15 = 15, + Case20 = 20, +}; + +llvm::Optional symbolizeMyIntEnum(uint32_t); +llvm::StringRef ConvertToString(MyIntEnum); +llvm::Optional ConvertToEnum(llvm::StringRef); +inline constexpr unsigned getMaxEnumValForMyIntEnum() { + return 20; +} + +} // namespace Inner +} // namespace Outer + +namespace llvm { +template<> struct DenseMapInfo { + using StorageInfo = llvm::DenseMapInfo; + + static inline Outer::Inner::MyIntEnum getEmptyKey() { + return static_cast(StorageInfo::getEmptyKey()); + } + + static inline Outer::Inner::MyIntEnum getTombstoneKey() { + return static_cast(StorageInfo::getTombstoneKey()); + } + + static unsigned getHashValue(const Outer::Inner::MyIntEnum &val) { + return StorageInfo::getHashValue(static_cast(val)); + } + + static bool isEqual(const Outer::Inner::MyIntEnum &lhs, const Outer::Inner::MyIntEnum &rhs) { + return lhs == rhs; + } +}; +} +``` + +The following will be generated via `mlir-tblgen -gen-enum-defs`: + +```c++ +namespace Outer { +namespace Inner { +llvm::StringRef ConvertToString(MyIntEnum val) { + switch (val) { + case MyIntEnum::Case15: return "Case15"; + case MyIntEnum::Case20: return "Case20"; + } + return ""; +} + +llvm::Optional ConvertToEnum(llvm::StringRef str) { + return llvm::StringSwitch>(str) + .Case("Case15", MyIntEnum::Case15) + .Case("Case20", MyIntEnum::Case20) 
+ .Default(llvm::None); +} +llvm::Optional symbolizeMyIntEnum(uint32_t value) { + switch (value) { + case 15: return MyIntEnum::Case15; + case 20: return MyIntEnum::Case20; + default: return llvm::None; + } +} + +} // namespace Inner +} // namespace Outer +``` + +Similarly for the following `BitEnumAttr` definition: + +```tablegen +def None: BitEnumAttrCase<"None", 0x0000>; +def Bit1: BitEnumAttrCase<"Bit1", 0x0001>; +def Bit2: BitEnumAttrCase<"Bit2", 0x0002>; +def Bit3: BitEnumAttrCase<"Bit3", 0x0004>; + +def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum", + [None, Bit1, Bit2, Bit3]>; +``` + +We can have: + +```c++ +// An example bit enum +enum class MyBitEnum : uint32_t { + None = 0, + Bit1 = 1, + Bit2 = 2, + Bit3 = 4, +}; + +llvm::Optional symbolizeMyBitEnum(uint32_t); +std::string stringifyMyBitEnum(MyBitEnum); +llvm::Optional symbolizeMyBitEnum(llvm::StringRef); +inline MyBitEnum operator|(MyBitEnum lhs, MyBitEnum rhs) { + return static_cast(static_cast(lhs) | static_cast(rhs)); +} +inline MyBitEnum operator&(MyBitEnum lhs, MyBitEnum rhs) { + return static_cast(static_cast(lhs) & static_cast(rhs)); +} +inline bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) { + return (static_cast(bits) & static_cast(bit)) != 0; +} + +namespace llvm { +template<> struct DenseMapInfo<::MyBitEnum> { + using StorageInfo = llvm::DenseMapInfo; + + static inline ::MyBitEnum getEmptyKey() { + return static_cast<::MyBitEnum>(StorageInfo::getEmptyKey()); + } + + static inline ::MyBitEnum getTombstoneKey() { + return static_cast<::MyBitEnum>(StorageInfo::getTombstoneKey()); + } + + static unsigned getHashValue(const ::MyBitEnum &val) { + return StorageInfo::getHashValue(static_cast(val)); + } + + static bool isEqual(const ::MyBitEnum &lhs, const ::MyBitEnum &rhs) { + return lhs == rhs; + } +}; +``` + +```c++ +std::string stringifyMyBitEnum(MyBitEnum symbol) { + auto val = static_cast(symbol); + // Special case for all bits unset. 
+ if (val == 0) return "None"; + + llvm::SmallVector strs; + if (1u & val) { strs.push_back("Bit1"); val &= ~1u; } + if (2u & val) { strs.push_back("Bit2"); val &= ~2u; } + if (4u & val) { strs.push_back("Bit3"); val &= ~4u; } + + if (val) return ""; + return llvm::join(strs, "|"); +} + +llvm::Optional symbolizeMyBitEnum(llvm::StringRef str) { + // Special case for all bits unset. + if (str == "None") return MyBitEnum::None; + + llvm::SmallVector symbols; + str.split(symbols, "|"); + + uint32_t val = 0; + for (auto symbol : symbols) { + auto bit = llvm::StringSwitch>(symbol) + .Case("Bit1", 1) + .Case("Bit2", 2) + .Case("Bit3", 4) + .Default(llvm::None); + if (bit) { val |= *bit; } else { return llvm::None; } + } + return static_cast(val); +} + +llvm::Optional symbolizeMyBitEnum(uint32_t value) { + // Special case for all bits unset. + if (value == 0) return MyBitEnum::None; + + if (value & ~(1u | 2u | 4u)) return llvm::None; + return static_cast(value); +} +``` + +TODO(b/132506080): This following is outdated. Update it. + +An attribute is a compile time known constant of an operation. Attributes are +required to be known to construct an operation (e.g., the padding behavior is +required to fully define the `conv2d` op). + +Attributes are defined as having a storage type (corresponding to a derived +class of `mlir::Attribute`), a return type (that corresponds to the C++ type to +use in the generation of the helper accessors) as well as method to convert +between the internal storage and the helper method. Derived attributes are a +special class of attributes that do not have storage but are instead calculated +based on the operation and its attributes. + +## Debugging Tips + +### Run `mlir-tblgen` to see the generated content + +TableGen syntax sometimes can be obscure; reading the generated content can be +a very helpful way to understand and debug issues. To build `mlir-tblgen`, run +`cmake --build . 
--target mlir-tblgen` in your build directory and find the
+`mlir-tblgen` binary in the `bin/` subdirectory. All the supported generators
+can be found via `mlir-tblgen --help`. For example, `--gen-op-decls` and
+`--gen-op-defs` as explained in [Generated C++ code](#generated-c++-code).
+
+To see the generated code, invoke `mlir-tblgen` with a specific generator by
+providing include paths via `-I`. For example,
+
+```sh
+# To see op C++ class declaration
+mlir-tblgen --gen-op-decls -I /path/to/mlir/include /path/to/input/td/file
+# To see op C++ class definition
+mlir-tblgen --gen-op-defs -I /path/to/mlir/include /path/to/input/td/file
+# To see op documentation
+mlir-tblgen --gen-op-doc -I /path/to/mlir/include /path/to/input/td/file
+
+# To see op interface C++ class declaration
+mlir-tblgen --gen-op-interface-decls -I /path/to/mlir/include /path/to/input/td/file
+# To see op interface C++ class definition
+mlir-tblgen --gen-op-interface-defs -I /path/to/mlir/include /path/to/input/td/file
+# To see op interface documentation
+mlir-tblgen --gen-op-interface-doc -I /path/to/mlir/include /path/to/input/td/file
+```
+
+
+## Appendix
+
+### Requirements and existing mechanisms analysis
+
+The op description should be as declarative as possible to allow a wide range
+of tools to work with them and query methods generated from them. In particular
+this means specifying traits, constraints and shape inference information in
+a way that is easily analyzable (e.g., avoid opaque calls to C++ functions where
+possible).
+
+We considered the approaches of several contemporary systems and focused on
+requirements that were desirable:
+
+* Ops registered using a registry separate from C++ code.
+  * Unknown ops are allowed in MLIR, so ops need not be registered. The
+    ability of the compiler to optimize those ops or graphs containing those
+    ops is constrained but correct. 
+ * The current proposal does not include a runtime op description, but it + does not preclude such description, it can be added later. + * The op registry is essential for generating C++ classes that make + manipulating ops, verifying correct construction etc. in C++ easier by + providing a typed representation and accessors. +* The op registry will be defined in + [TableGen](https://llvm.org/docs/TableGen/index.html) and be used to + generate C++ classes and utility functions + (builder/verifier/parser/printer). + * TableGen is a modelling specification language used by LLVM's backends + and fits in well with trait-based modelling. This is an implementation + decision and there are alternative ways of doing this. But the + specification language is good for the requirements of modelling the + traits (as seen from usage in LLVM processor backend modelling) and easy + to extend, so a practical choice. If another good option comes up, we + will consider it. +* MLIR allows both defined and undefined ops. + * Defined ops should have fixed semantics and could have a corresponding + reference implementation defined using, for example, EDSC. + * Dialects are under full control of the dialect owner and normally live + with the framework of the dialect. +* The op's traits (e.g., commutative) are modelled along with the op in the + registry. +* The op's operand/return type constraints are modelled along with the op in + the registry (see [Shape inference](#shape-inference) discussion below), + this allows (e.g.) optimized concise syntax in textual dumps. +* Behavior of the op is documented along with the op with a summary and a + description. The description is written in markdown and extracted for + inclusion in the generated LangRef section of the dialect. 
+* The generic assembly form of printing and parsing is available as normal, + but a custom parser and printer can either be specified or automatically + generated from an optional string representation showing the mapping of the + "assembly" string to operands/type. + * Parser-level remappings (e.g., `eq` to enum) will be supported as part + of the parser generation. +* Matching patterns are specified separately from the op description. + * Contrasted with LLVM there is no "base" set of ops that every backend + needs to be aware of. Instead there are many different dialects and the + transformations/legalizations between these dialects form a graph of + transformations. +* Reference implementation may be provided along with the op definition. + + * The reference implementation may be in terms of either standard ops or + other reference implementations. + + TODO: document expectation if the dependent op's definition changes. + +### A proposal for auto-generating printer and parser methods + +NOTE: Auto-generating printing/parsing (as explained in the below) has _not_ +been prototyped, and potentially just being able to specify custom printer/ +parser methods are sufficient. This should presumably be influenced by the +design of the assembler/disassembler logic that LLVM backends get for free +for machine instructions. + +The custom assembly form of the operation is specified using a string with +matching operation name, operands and attributes. With the ability +to express additional information that needs to be parsed to build the +operation: + +```tablegen +tfl.add $lhs, $rhs {fused_activation_function: $fused_activation_function}: ${type(self)} +``` + +1. The output is never shown in the "mnemonics" string as that is fixed form + and cannot be altered. +1. Custom parsing of ops may include some punctuation (e.g., parenthesis). +1. The operands/results are added to the created operation in the order that + they are shown in the input and output dags. +1. 
The `${type(self)}` operator is used to represent the type of the operator. + The type of operands can also be queried. +1. Attributes names are matched to the placeholders in the mnemonic strings. + E.g., attribute axis is matched with `$axis`. Custom parsing for attribute + type can be defined along with the attribute definition. +1. The information in the custom assembly form should be sufficient to invoke + the builder generated. That may require being able to propagate information + (e.g., the `$lhs` has the same type as the result). + +Printing is effectively the inverse of the parsing function generated with the +mnemonic string serving as a template. + +### Shape inference + +Type constraints are along (at least) three axis: 1) elemental type, 2) rank +(including static or dynamic), 3) dimensions. While some ops have no compile +time fixed shape (e.g., output shape is dictated by data) we could still have +some knowledge of constraints/bounds in the system for that op (e.g., the output +of a `tf.where` is at most the size of the input data). And so there are +additional valuable constraints that could be captured even without full +knowledge. + +Initially the shape inference will be declaratively specified using: + +* Constraint on the operands of an operation directly. For example + constraining the input type to be tensor/vector elements or that the + elemental type be of a specific type (e.g., output of sign is of elemental + type `i1`) or class (e.g., float like). +* Constraints across operands and results of an operation. For example, + enabling specifying equality constraints on type/constituents of a type + (shape and elemental type) between operands and results (e.g., the output + type of an add is the same as those of the input operands). + +In general there is an input/output transfer function which maps the inputs to +the outputs (e.g., given input X and Y [or slices thereof] with these sizes, the +output is Z [or this slice thereof]). 
Such a function could be used to determine +the output type (shape) for given input type (shape). + +But shape functions are determined by attributes and could be arbitrarily +complicated with a wide-range of specification possibilities. Equality +relationships are common (e.g., the elemental type of the output matches the +primitive type of the inputs, both inputs have exactly the same type [primitive +type and shape]) and so these should be easy to specify. Algebraic relationships +would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0 +is `[n+n, m]` matrix), while some ops only have defined shapes under certain +cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if +`b == c`). As ops are also verified, the shape inference need only specify rules +for the allowed cases (e.g., shape inference for matmul can ignore the case +where `b != c`), which would simplify type constraint specification. + +Instead of specifying an additional mechanism to specify a shape transfer +function, the reference implementation of the operation will be used to derive +the shape function. The reference implementation is general and can support the +arbitrary computations needed to specify output shapes. 
+ +[TableGen]: https://llvm.org/docs/TableGen/index.html +[TableGenIntro]: https://llvm.org/docs/TableGen/LangIntro.html +[TableGenRef]: https://llvm.org/docs/TableGen/LangRef.html +[TableGenBackend]: https://llvm.org/docs/TableGen/BackEnds.html#introduction +[OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td +[OpDefinitionsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/OpDefinitionsGen.cpp +[EnumsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/EnumsGen.cpp +[StringAttr]: https://github.com/tensorflow/mlir/blob/master/g3doc/LangRef.md#string-attribute +[IntegerAttr]: https://github.com/tensorflow/mlir/blob/master/g3doc/LangRef.md#integer-attribute diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md new file mode 100644 index 0000000000000000000000000000000000000000..78ea257b57bc336d65cab8197898639f5ac20cd4 --- /dev/null +++ b/mlir/docs/Passes.md @@ -0,0 +1,298 @@ +# MLIR Passes + +This document describes the available MLIR passes and their contracts. + +[TOC] + +## Affine control lowering (`-lower-affine`) + +Convert operations related to affine control into a graph of blocks using +operations from the standard dialect. + +Loop statements are converted to a subgraph of blocks (initialization, condition +checking, subgraph of body blocks) with loop induction variable being passed as +the block argument of the condition checking block. Conditional statements are +converted to a subgraph of blocks (chain of condition checking with +short-circuit logic, subgraphs of 'then' and 'else' body blocks). `affine.apply` +operations are converted into sequences of primitive arithmetic operations that +have the same effect, using operands of the `index` type. Consequently, named +maps and sets may be removed from the module. 
+ +For example, `%r = affine.apply (d0, d1)[s0] -> (d0 + 2*d1 + s0)(%d0, %d1)[%s0]` +can be converted into: + +```mlir +%d0 = <...> +%d1 = <...> +%s0 = <...> +%0 = constant 2 : index +%1 = muli %0, %d1 +%2 = addi %d0, %1 +%r = addi %2, %s0 +``` + +### Input invariant + +- no `Tensor` types; + +These restrictions may be lifted in the future. + +### Output IR + +Functions with `affine.for` and `affine.if` operations eliminated. These +functions may contain operations from the Standard dialect in addition to those +already present before the pass. + +### Invariants + +- Functions without a body are not modified. +- The semantics of the other functions is preserved. +- Individual operations other than those mentioned above are not modified if + they do not depend on the loop iterator value or on the result of + `affine.apply`. + +## Conversion from Standard to LLVM IR dialect (`-convert-std-to-llvm`) + +Convert standard operations into the LLVM IR dialect operations. + +### Input invariant + +- operations including: arithmetic on integers and floats, constants, direct + calls, returns and branches; +- no `tensor` types; +- all `vector` are one-dimensional; +- all blocks are reachable by following the successors of the first basic + block; + +If other operations are present and their results are required by the LLVM IR +dialect operations, the pass will fail. Any LLVM IR operations or types already +present in the IR will be kept as is. + +### Output IR + +Functions converted to LLVM IR. Function arguments types are converted +one-to-one. Function results are converted one-to-one and, in case more than 1 +value is returned, packed into an LLVM IR struct type. Function calls and +returns are updated accordingly. Block argument types are updated to use LLVM IR +types. 
+ +## Data Copy DMA generation (`-affine-data-copy-generate`) + +Replaces all loads and stores on memref's living in 'slowMemorySpace' by +introducing DMA operations (strided DMA if necessary) to transfer data to/from +`fastMemorySpace` and rewriting the original load's/store's to instead +load/store from the allocated fast memory buffers. Additional options specify +the identifier corresponding to the fast memory space and the amount of fast +memory space available. The pass traverses through the nesting structure, +recursing to inner levels if necessary to determine at what depth DMA transfers +need to be placed so that the allocated buffers fit within the memory capacity +provided. If this is not possible (for example, when the elemental type itself +is of size larger than the DMA capacity), an error with location information is +emitted. The DMA transfers are also hoisted up past all loops with respect to +which the transfers are invariant. + +Input + +```mlir +func @loop_nest_tiled() -> memref<256x1024xf32> { + %0 = alloc() : memref<256x1024xf32> + affine.for %i0 = 0 to 256 step 32 { + affine.for %i1 = 0 to 1024 step 32 { + affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { + affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { + %1 = affine.load %0[%i2, %i3] : memref<256x1024xf32> + } + } + } + } + return %0 : memref<256x1024xf32> +} +``` + +Output (with flags: -affine-data-copy-generate -affine-data-copy-generate-fast-mem-space=2) + +```mlir +module { + func @loop_nest_tiled() -> memref<256x1024xf32> { + %c262144 = constant 262144 : index + %c0 = constant 0 : index + %0 = alloc() : memref<256x1024xf32> + %1 = alloc() : memref<256x1024xf32, 2> + %2 = alloc() : memref<1xi32> + affine.dma_start %0[%c0, %c0], %1[%c0, %c0], %2[%c0], %c262144 : memref<256x1024xf32>, memref<256x1024xf32, 2>, memref<1xi32> + affine.dma_wait %2[%c0], %c262144 : memref<1xi32> + affine.for %arg0 = 0 to 256 step 32 { + affine.for %arg1 = 0 to 1024 step 32 { + 
affine.for %arg2 = #map1(%arg0) to #map2(%arg0) { + affine.for %arg3 = #map1(%arg1) to #map2(%arg1) { + %3 = affine.load %1[%arg2, %arg3] : memref<256x1024xf32, 2> + } + } + } + } + dealloc %2 : memref<1xi32> + dealloc %1 : memref<256x1024xf32, 2> + return %0 : memref<256x1024xf32> + } +} +``` + +## Loop tiling (`-affine-loop-tile`) + +Performs tiling or blocking of loop nests. It currently works on perfect loop +nests. + +## Loop unroll (`-affine-loop-unroll`) + +This pass implements loop unrolling. It is able to unroll loops with arbitrary +bounds, and generate a cleanup loop when necessary. + +## Loop unroll and jam (`-affine-loop-unroll-jam`) + +This pass implements unroll and jam for loops. It works on both perfect or +imperfect loop nests. + +## Loop fusion (`-affine-loop-fusion`) + +Performs fusion of loop nests using a slicing-based approach. The fused loop +nests, when possible, are rewritten to access significantly smaller local +buffers instead of the original memref's, and the latter are often +either completely optimized away or contracted. This transformation leads to +enhanced locality and lower memory footprint through the elimination or +contraction of temporaries / intermediate memref's. These benefits are sometimes +achieved at the expense of redundant computation through a cost model that +evaluates available choices such as the depth at which a source slice should be +materialized in the designation slice. + +## Memref bound checking (`-memref-bound-check`) + +Checks all load's and store's on memref's for out of bound accesses, and reports +any out of bound accesses (both overrun and underrun) with location information. 
+ +```mlir +test/Transforms/memref-bound-check.mlir:19:13: error: 'load' op memref out of upper bound access along dimension #2 + %x = load %A[%idx0, %idx1] : memref<9 x 9 x i32> + ^ +test/Transforms/memref-bound-check.mlir:19:13: error: 'load' op memref out of lower bound access along dimension #2 + %x = load %A[%idx0, %idx1] : memref<9 x 9 x i32> + ^ +``` + +## Memref dataflow optimization (`-memref-dataflow-opt`) + +This pass performs store to load forwarding for memref's to eliminate memory +accesses and potentially the entire memref if all its accesses are forwarded. + +Input + +```mlir +func @store_load_affine_apply() -> memref<10x10xf32> { + %cf7 = constant 7.0 : f32 + %m = alloc() : memref<10x10xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.store %cf7, %m[%i0, %i1] : memref<10x10xf32> + %v0 = affine.load %m[%i0, %i1] : memref<10x10xf32> + %v1 = addf %v0, %v0 : f32 + } + } + return %m : memref<10x10xf32> +} +``` + +Output + +```mlir +module { + func @store_load_affine_apply() -> memref<10x10xf32> { + %cst = constant 7.000000e+00 : f32 + %0 = alloc() : memref<10x10xf32> + affine.for %arg0 = 0 to 10 { + affine.for %arg1 = 0 to 10 { + affine.store %cst, %0[%arg0, %arg1] : memref<10x10xf32> + %1 = addf %cst, %cst : f32 + } + } + return %0 : memref<10x10xf32> + } +} + +``` + +## Memref dependence analysis (`-memref-dependence-check`) + +This pass performs dependence analysis to determine dependences between pairs of +memory operations (load's and store's) on memref's. Dependence analysis exploits +polyhedral information available (affine maps, expressions, and affine.apply +operations) to precisely represent dependences using affine constraints, while +also computing dependence vectors from them, where each component of the +dependence vector provides a lower and an upper bound on the dependence distance +along the corresponding dimension. 
+ +```mlir +test/Transforms/memref-dataflow-opt.mlir:232:7: note: dependence from 2 to 1 at depth 1 = ([1, 1], [-inf, +inf]) + store %cf9, %m[%idx] : memref<10xf32> +``` + +## Pipeline data transfer (`-affine-pipeline-data-transfer`) + +This pass performs a transformation to overlap non-blocking DMA operations in a +loop with computations through double buffering. This is achieved by advancing +dma_start operations with respect to other operations. + +Input + +```mlir +func @pipelinedatatransfer() { + %0 = alloc() : memref<256xf32> + %1 = alloc() : memref<32xf32, 1> + %2 = alloc() : memref<1xf32> + %c0 = constant 0 : index + %c128 = constant 128 : index + affine.for %i0 = 0 to 8 { + affine.dma_start %0[%i0], %1[%i0], %2[%c0], %c128 : memref<256xf32>, memref<32xf32, 1>, memref<1xf32> + affine.dma_wait %2[%c0], %c128 : memref<1xf32> + %3 = affine.load %1[%i0] : memref<32xf32, 1> + %4 = "compute"(%3) : (f32) -> f32 + affine.store %4, %1[%i0] : memref<32xf32, 1> + } + return +} +``` + +Output + +```mlir +module { + func @pipelinedatatransfer() { + %c8 = constant 8 : index + %c0 = constant 0 : index + %0 = alloc() : memref<256xf32> + %c0_0 = constant 0 : index + %c128 = constant 128 : index + %1 = alloc() : memref<2x32xf32, 1> + %2 = alloc() : memref<2x1xf32> + affine.dma_start %0[%c0], %1[%c0 mod 2, %c0], %2[%c0 mod 2, symbol(%c0_0)], %c128 : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> + affine.for %arg0 = 1 to 8 { + affine.dma_start %0[%arg0], %1[%arg0 mod 2, %arg0], %2[%arg0 mod 2, symbol(%c0_0)], %c128 : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> + %8 = affine.apply #map3(%arg0) + %9 = affine.apply #map4(%8) + %10 = affine.apply #map4(%8) + affine.dma_wait %2[%8 mod 2, symbol(%c0_0)], %c128 : memref<2x1xf32> + %11 = affine.load %1[%8 mod 2, %8] : memref<2x32xf32, 1> + %12 = "compute"(%11) : (f32) -> f32 + affine.store %12, %1[%8 mod 2, %8] : memref<2x32xf32, 1> + } + %3 = affine.apply #map3(%c8) + %4 = affine.apply #map4(%3) + %5 = affine.apply 
#map4(%3) + affine.dma_wait %2[%3 mod 2, symbol(%c0_0)], %c128 : memref<2x1xf32> + %6 = affine.load %1[%3 mod 2, %3] : memref<2x32xf32, 1> + %7 = "compute"(%6) : (f32) -> f32 + affine.store %7, %1[%3 mod 2, %3] : memref<2x32xf32, 1> + dealloc %2 : memref<2x1xf32> + dealloc %1 : memref<2x32xf32, 1> + return + } +} +``` diff --git a/mlir/docs/Quantization.md b/mlir/docs/Quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..99e450ca84dacde19b2691bab228c30e3ade7fd6 --- /dev/null +++ b/mlir/docs/Quantization.md @@ -0,0 +1,359 @@ +# MLIR Quantization + +This document outlines the design of the MLIR quantization system. While the +term "quantization" is highly overloaded, in this case, it refers to a fairly +narrow scope of techniques in use to enable conversion of floating-point +computations to corresponding and plausible variants expressed in integer math +for inference, as has historically been supported by low-bit depth inference +engines such as TFLite, various accelerator hardware, and many DSPs. + +Much of this is inspired by the approach taken +[in this paper](https://arxiv.org/abs/1712.05877) with many extensions and +adaptations folded in. It specifically documents the positions that MLIR has +taken on the topic, and is not a general reference. + +[TOC] + +## Uniform quantization + +The primary quantization mechanism supported by MLIR is a scheme which can +express fixed point and affine transformations via uniformly spaced point on the +Real number line. + +Further, the scheme can be applied: + +* *per-layer* : Applying to every value within the target type. +* *per-axis* (also called *per-channel*) : Applying individually to each index + along a specific axis of a tensor type. + +### Fixed point values + +[Fixed point](https://en.wikipedia.org/wiki/Fixed-point_arithmetic) values are a +[Real](https://en.wikipedia.org/wiki/Real_number) number divided by a *scale*. +We will call the result of the divided Real the *scaled value*. 
+ +$$ real\_value = scaled\_value * scale $$ + +The scale can be interpreted as the distance, in Real units, between neighboring +scaled values. For example, if the scale is $$ \pi $$, then fixed point values +with this scale can only represent multiples of $$ \pi $$, and nothing in +between. The maximum rounding error to convert an arbitrary Real to a fixed +point value with a given $$ scale $$ is $$ \frac{scale}{2} $$. Continuing the +previous example, when $$ scale = \pi $$, the maximum rounding error will be $$ +\frac{\pi}{2} $$. + +Multiplication can be performed on scaled values with different scales, using +the same algorithm as multiplication of Real values (note that product scaled +value has $$ scale_{product} = scale_{left \mbox{ } operand} * scale_{right +\mbox{ } operand} $$). Addition can be performed on scaled values, as long as +they have the same scale, using the same algorithm as addition of Real values. +This makes it convenient to represent scaled values on a computer as signed +integers, and perform arithmetic on those signed integers, because the results +will be correct scaled values. + +### Affine values + +Mathematically speaking, affine values are the result of +[adding a Real-valued *zero point*, to a scaled value](https://en.wikipedia.org/wiki/Affine_transformation#Representation). +Or equivalently, subtracting a zero point from an affine value results in a +scaled value: + +$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$ + +Essentially, affine values are a shifting of the scaled values by some constant +amount. Arithmetic (i.e., addition, subtraction, multiplication, division) +cannot, in general, be directly performed on affine values; you must first +[convert](#affine-to-fixed-point) them to the equivalent scaled values. + +As alluded to above, the motivation for using affine values is to more +efficiently represent the Real values that will actually be encountered during +computation. 
Frequently, the Real values that will be encountered are not +symmetric around the Real zero. We also make the assumption that the Real zero +is encountered during computation, and should thus be represented. + +In this case, it's inefficient to store scaled values represented by signed +integers, as some of the signed integers will never be used. The bit patterns +corresponding to those signed integers are going to waste. + +In order to exactly represent the Real zero with an integral-valued affine +value, the zero point must be an integer between the minimum and maximum affine +value (inclusive). For example, given an affine value represented by an 8 bit +unsigned integer, we have: $$ 0 \leq zero\_point \leq 255$$. This is important, +because in deep neural networks' convolution-like operations, we frequently +need to zero-pad inputs and outputs, so zero must be exactly representable, or +the result will be biased. + +### Relation + +Real values, fixed point values, and affine values relate through the following +equation, which demonstrates how to convert one type of number to another: + +$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$ + +Note that computers generally store mathematical values using a finite number of +bits. Thus, while the above conversions are exact, to store the result in a +finite number of bits, we must, in general, round the result of the conversion +(this applies to both cases: storing using floating point and storing using +fixed point). Note that a full discussion of rounding behavior is outside the +scope of this document, and it is safe to assume unless otherwise stated that +rounding should be according to the IEEE754 default of RNE (where hardware +permits). + +### Converting between Real and fixed point or affine + +To convert a Real value to a fixed point value, you must know the scale. To +convert a Real value to an affine value, you must know the scale and zero point. 
+ +#### Real to affine + +To convert an input tensor of Real-valued elements (usually represented by a +floating point format, frequently +[Single precision](https://en.wikipedia.org/wiki/Single-precision_floating-point_format)) +to a tensor of affine elements represented by an integral type (e.g. 8-bit +unsigned integer), the following conversion can be performed (note that it is +not required that all representable values of the integral type are used): + +$$ +\begin{align*} +af&fine\_value_{uint8 \, or \, uint16} \\ + &= clampToTargetSize(roundToNearestInteger( \frac{real\_value_{Single}}{scale_{Single}})_{sint32} + zero\_point_{uint8 \, or \, uint16}) +\end{align*} +$$ + +In the above, we assume that $$real\_value$$ is a Single, $$scale$$ is a Single, +$$roundToNearestInteger$$ returns a signed 32 bit integer, and $$zero\_point$$ +is an unsigned 8 or 16 bit integer. Note that bit depth and number of fixed +point values are indicative of common types on typical hardware but is not +constrained to particular bit depths or a requirement that the entire range of +an N-bit integer is used. + +#### Affine to Real + +To convert an output tensor of affine elements represented by uint8 +or uint16 to a tensor of Real-valued elements (usually represented with a +floating point format, frequently Single precision), the following conversion +can be performed: + +$$ +\begin{align*} +re&al\_value_{Single} \\ + &= roundToNearestFloat((affine\_value_{uint8 \, or \, uint16} - zero\_point_{uint8 \, or \, uint16})_{sint32})_{Single} * scale_{Single} +\end{align*} +$$ + +In the above, we assume that the result of subtraction is in 32-bit signed +integer format, and that $$roundToNearestFloat$$ returns a Single. + +#### Affine to fixed point + +When the affine and fixed point scales are the same, subtract the zero point +from the affine value to get the equivalent fixed point value. 
+ +$$ +scaled\_value = affine\_value_{non\mbox{-}negative} - zero\_point_{non\mbox{-}negative} +$$ + +#### Fixed point to affine + +When the affine and fixed point scales are the same, add the zero point to the +fixed point value to get the equivalent affine value. + +$$ +affine\_value_{non\mbox{-}negative} = scaled\_value + zero\_point_{non\mbox{-}negative} +$$ + +## Usage within MLIR + +There are several components to the quantization system being developed within +MLIR: + +* *Quantization* dialect containing: + + * A family of [QuantizedTypes](#quantized-type) which represent the + mapping between *expressed* values (typically of a floating point + computer type) and *storage* values (typically of an integral computer + type). + * [Type conversion ops](#quantized-type-conversion-ops) for converting + between types based on a QuantizedType and its *expressed* and *storage* + sub-types. + * [Instrumentation ops](#instrumentation-and-constraint-ops) for assigning + instrumentation points within the computation where runtime statistics + may help guide the quantization process. + +* [Integration with simulated quantization at training time](#integration-with-simulated-quantization-at-training-time) + +* [TFLite native quantization](#tflite-native-quantization) + + * The TFLite op-set natively supports uniform-quantized variants. + * Passes and tools exist to convert directly from the *TensorFlow* dialect + to the TFLite quantized op-set. + +* [*FxpMath* dialect](#fxpmath-dialect) containing (experimental) generalized + representations of fixed-point math ops and conversions: + + * [Real math ops](#real-math-ops) representing common combinations of + arithmetic operations that closely match corresponding fixed-point math + concepts (as opposed to being spread across multiple ops as is typical + in source dialects). 
+ * [Fixed-point math ops](#fixed-point-math-ops) for carrying out + computations on integers, as are typically needed by uniform + quantization schemes. + * Passes to lower from real math ops to fixed-point math ops. + +* [Solver tools](#solver-tools) which can (experimentally and generically) + operate on computations expressed in the *FxpMath* dialect in order to + convert from floating point types to appropriate *QuantizedTypes*, allowing + the computation to be further lowered to integral math ops. + +Not every application of quantization will use all facilities. Specifically, the +TensorFlow to TensorFlow Lite conversion uses the QuantizedTypes but has its own +ops for type conversion and expression of the backing math. + +## Quantization Dialect + +### Quantized type + +TODO : Flesh this section out. + +* QuantizedType base class +* UniformQuantizedType + +### Quantized type conversion ops + +* qcast : Convert from an expressed type to QuantizedType +* dcast : Convert from a QuantizedType to its expressed type +* scast : Convert between a QuantizedType and its storage type + +### Instrumentation and constraint ops + +* const_fake_quant : Emulates the logic of the historic TensorFlow + fake_quant_with_min_max_args op. +* stats_ref : Declares that statistics should be gathered at this point with a + unique key and made available to future passes of the solver. +* stats : Declares inline statistics (per layer and per axis) for the point in + the computation. stats_ref ops are generally converted to stats ops once + trial runs have been performed. +* coupled_ref : Declares points in the computation to be coupled from a type + inference perspective based on a unique key. 
+ +## Integration with simulated quantization at training time + +TensorFlow has historically used the +[tf.quantization.fake_quant_\*](https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args) +family of operations to simulate the effect of quantization at training time. + +As originally implemented, TensorFlow Lite was the primary user of such +operations at inference time. When quantized inference was enabled, if every +eligible tensor passed through an appropriate fake_quant node (the rules of +which tensors can have fake_quant applied are somewhat involved), then +TensorFlow Lite would use the attributes of the fake_quant ops to make a +judgment about how to convert to use kernels from its quantized ops subset. + +In MLIR-based quantization, fake_quant_\* ops are handled by converting them to +a sequence of *qcast* (quantize) followed by *dcast* (dequantize) with an +appropriate *UniformQuantizedType* as the target of the qcast operation. + +This allows subsequent compiler passes to preserve the knowledge that +quantization was simulated in a certain way while giving the compiler +flexibility to move the casts as it simplifies the computation and converts it +to a form based on integral arithmetic. + +This scheme also naturally allows computations that are *partially quantized* +where the parts which could not be reduced to integral ops are still carried out +in floating point with appropriate conversions at the boundaries. + +## TFLite Native Quantization + +TODO : Flesh this out + +### General algorithm + +1. Take input min/max information and set the ArrayInfo (which really is + InputOrOutputArrayInfo). +1. In LegalizeTF, convert ArrayInfo min/max to tf.Quantize and tf.Dequantize + nodes. (or tf.FakeQuant) Convert all constant FakeQuants to (tf.FQ -> tfl.Q + -> tfl.DQ). +1. Hardcode logic/propagation needs to happen here. +1. Run TF constant folding. +1. In PrepareTFL, convert all tf.FQ to (tfl.Q -> tfl.DQ). +1. 
Run quantization pass that takes (tfl.DQ (for both input and weights) -> op + -> tfl.Q) and replaces it with (op). Also replace (constant_float -> tfl.Q) + with (constant_quant). + +## FxpMath Dialect + +### Real math ops + +Note that these all support explicit clamps, which allows for simple fusions and +representation of some common sequences of quantization-compatible math. In +addition, some support explicit biases, which are often represented as separate +adds in source dialects. + +TODO: This op set is still evolving and needs to be completed. + +* RealBinaryOp + * RealAddEwOp + * RealSubEwOp + * RealMulEwOp + * RealDivEwOp +* RealUnaryOp + * IDENTITY + * TANH + * SIGMOID + * EXP + * LOG + * NEG + * RSQRT + * SIN + * SQUARE + * SQRT + * CMPZ + * CMPNZ + * CMPLZ + * CMPGZ + +### Fixed-point math ops + +TODO: This op set only has enough ops to lower a simple power-of-two +RealAddEwOp. + +* RoundingDivideByPotFxpOp +* SaturatingAddFxpOp + +## Solver tools + +Solver tools exist to analyze an MLIR-computation, expressed in either a +supported source dialect or in the *real math ops* set and solve for appropriate +QuantizedTypes that allow the computation to be lowered to integral math. + +These tools are an active area of work and may be expanded in the future to +adjacent areas such as solving for transformations to other kinds of lower +precision types (i.e. bfloat16 or fp16). + +Solver tools are expected to operate in several modes, depending on the +computation and the manner in which it was trained: + +* *Transform* : With all available information in the MLIR computation, infer + boundaries where the computation can be carried out with integral math and + change types accordingly to appropriate QuantizedTypes: + + * For passthrough ops which do not perform active math, change them to + operate directly on the storage type, converting in and out at the edges + via scast ops. + * For ops that have the *Quantizable* trait, the type can be set directly. 
+ This includes ops from the [real math ops set](#real-math-ops). + * For others, encase them in appropriate dcast/qcast ops, presuming that + some follow-on pass will know what to do with them. + +* *Instrument* : Most of the time, there are not sufficient implied + constraints within a computation to perform many transformations. For this + reason, the solver can insert instrumentation ops at points where additional + runtime statistics may yield solutions. It is expected that such + computations will be lowered as-is for execution, run over an appropriate + eval set, and statistics at each instrumentation point made available for a + future invocation of the solver. + +* *Simplify* : A variety of passes and simplifications are applied once + QuantizedTypes are added in order to arrive at a computation that is + expressed in as much integral math, with the fewest number of casts as + possible. diff --git a/mlir/docs/QuickstartRewrites.md b/mlir/docs/QuickstartRewrites.md new file mode 100644 index 0000000000000000000000000000000000000000..6a4a7cca8b88d9d0f282de9c82c78518f3f555c9 --- /dev/null +++ b/mlir/docs/QuickstartRewrites.md @@ -0,0 +1,255 @@ +# Quickstart tutorial to adding MLIR graph rewrite + +This document will present a quickstart to adding graph rewrites. We shall start +by defining an operation, showing multiple ways to define the rewrite using +patterns, as well as defining the rewrite using a graph walker (note: using +patterns and the rewrite engine is preferred, showing the walker is for +demonstration purposes). + +See [MLIR specification](LangRef.md) for more information about MLIR, the +structure of the IR, operations, etc. See +[Table-driven Operation Definition](OpDefinitions.md) and +[Declarative Rewrite Rule](DeclarativeRewrites.md) for the detailed explanation +of all available mechanisms for defining operations and rewrites in a +table-driven manner. 
+ +## Adding operation + +An operation in MLIR is specified using a definition in a +[TableGen](https://llvm.org/docs/TableGen/LangIntro.html) file. TableGen is a +modeling tool used to specify the ops, from which the C++ code to interact with +these operations is generated. To define an operation one needs to specify: + +* The operation name. This name is a unique identifier of the operation within + MLIR. Most operations are within a dialect, so for example one could have + `tfl.add` to represent the add operation in the TensorFlow Lite dialect. + Instead of repeating the dialect in the op definition, a base class for the + op dialect is commonly created that prepends the dialect namespace given an + op name. +* The traits of the operation. These allow you to specify traits of the + operation, such as whether it has side effects or whether it should be + verified that the operands and result types are the same. These are backed + by C++ traits that perform the verification. +* The arguments of the operation. These are the input operands (values at + runtime produced by other ops) and attributes (compile time known constant + values that affect the behavior of the op) that are the inputs of/define the + behavior of the operation. The input operands may be named, the attributes + must be named. +* The result(s) of the operation. These may again be named or not. +* Documentation of the operation. This includes a one-line summary as well as + a longer human-readable description of the operation. +* Dialect specific information. Additional information could be added to the + operation definition that is only used by dialect specific drivers. These + are ignored by the main op and doc generators, but could be used in, say, + the translation from a dialect to another representation. + +```tablegen +def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameValueType]>, + Results<(outs Tensor)> { + let arguments = (ins + F32Tensor:$x, + // Slope of the activation function at x < 0. 
+ F32Attr:$alpha + ); + + let summary = "Leaky ReLU operator"; + let description = [{ + Element-wise Leaky ReLU operator + x -> x >= 0 ? x : (alpha * x) + }]; + + // TFLite specific attribute that is used when generating the output + // flatbuffer. + let hasOptions = 1; +} +``` + +Note in the above the result types and inputs are specified in different ways, +one by way of trait and the other by way of let. It is possible to specify both +in either way. + + + +Operations can also have custom parser, printer, builder, verifier, constant +folder, or canonicalizer. These require specifying additional C++ methods to +invoke for additional functionality. For example, if an operation is marked to +have a folder, the constant folder also needs to be added, e.g.,: + +```c++ +OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) { + if (unable_to_fold) + return {}; + .... + return val; +} +``` + +## Adding patterns + +There are multiple forms of graph rewrite that can be performed in MLIR. One of +the most common is DAG tile to DAG tile rewrite. Patterns provide a concise way +to express this transformation as a pair of source pattern to match and +resultant pattern. There are both the C++ classes to represent this +transformation, as well as the patterns in TableGen from which these can be +generated. + +### TableGen patterns + +Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to +TensorFlow Lite's `LeakyRelu`: + +```tablegen +def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)> +``` + +The pattern is specified by instantiating a `Pat` with a source and result DAG. +The arguments in the source pattern are captured and can be used in the result +pattern. This is a simple pattern as we have a 1:1 mapping and the attribute +does not need to be transformed (e.g., both have a floating point attribute for +alpha). 
The names of the attributes specified in the pattern are for +matching/referencing and need not match the original attribute name in the op +definition but the order of arguments of the dags do need to match. + +To specify a pattern, both the source and resultant ops need to be defined using +TableGen. + +If this were a more advanced pattern that the current framework could not express +as a destination then one could use a general native code fallback method. This +consists of defining a pattern as well as adding a C++ function to perform the +replacement: + +```tablegen +def createTFLLeakyRelu : NativeCodeCall< + "createTFLLeakyRelu($_builder, $0->getDefiningOp(), $1, $2)">; + +def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a), + (createTFLLeakyRelu $old_value, $arg, $a)>; +``` + +```c++ +static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op, + Value operand, Attribute attr) { + return rewriter.create<TFL::LeakyReluOp>( + op->getLoc(), operand->getType(), /*arg=*/operand, + /*alpha=*/attr.cast<FloatAttr>()); +} +``` + +This allows for arbitrarily complex builders. On the input pattern side one can +express multi-op patterns with constraints on input operands and attributes. But +input patterns cannot yet express constraints across multiple operands/attributes. + +### Register the pattern + +The file containing the patterns needs to be processed using `mlir-tblgen` +`-gen-rewriters` during compilation time. It can be invoked with the following +configuration in CMake: + +```cmake +set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>) +mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters) +add_public_tablegen_target(<name-of-the-cmake-target>) +``` + +Then you can `#include` the generated file in any C++ implementation file you +like. (You will also need to make sure the library depends on the CMake target +defined in the above.) 
The generated file will have a `populateWithGenerated( +MLIRContext *context, OwningRewritePatternList *patterns)` function that you can +use to collect all the generated patterns inside `patterns` and then use +`patterns` in any pass you would like. + +### C++ rewrite specification + +In case patterns are not sufficient there is also the fully C++ way of +expressing a rewrite: + +```c++ +/// Multi-step rewrite using "match" and "rewrite". This allows for separating +/// the concerns of matching and rewriting. +struct ConvertTFLeakyRelu : public RewritePattern { + ConvertTFLeakyRelu(MLIRContext *context) + : RewritePattern("tf.LeakyRelu", 1, context) {} + + PatternMatchResult match(Operation *op) const override { + return matchSuccess(); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op->getResult(0)->getType(), op->getOperand(0), + /*alpha=*/op->getAttrOfType("alpha")); + } +}; + +/// Single-step rewrite with "matchAndRewrite". This allows for performing the +/// rewrite immediately upon a successful match. +struct ConvertTFLeakyRelu : public RewritePattern { + ConvertTFLeakyRelu(MLIRContext *context) + : RewritePattern("tf.LeakyRelu", 1, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op->getResult(0)->getType(), op->getOperand(0), + /*alpha=*/op->getAttrOfType("alpha")); + return matchSuccess(); + } +}; +``` + +In the C++ rewrite the static benefit of the rewrite pattern is specified at +construction. While in the pattern generator a simple heuristic is currently +employed based around the number of ops matched and replaced. + +The above rule did not capture the matching operands/attributes, but in general +the `match` function in a multi-step rewrite may populate and return a +`PatternState` (or class derived from one) to pass information extracted during +matching to the rewrite. 
A single-step rewrite with the `matchAndRewrite` +function has the benefit of being able to directly use any values created when +matching; removing the need for `PatternState`. + +## Testing + +MLIR uses the [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM Integrated +Testing) tool for performing testing. Testing is performed by way of creating +the input IR file, running a transformation and then verifying the output IR. +C++ unit tests are the exception, with the IR transformation serving as the core +testing mechanism. This results in fewer binaries that need to be built (and +linked) and forces a focus on the representation as an important piece. + +For the legalization transform above we would have a test (probably as part of +the legalization pass test in TensorFlow Lite) such as: + +```mlir +// RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s + +func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> { + %2 = "tf.LeakyRelu"(%arg0) {alpha: 0.1} : (tensor<1xf32>) -> tensor<1xf32> + return %2: tensor<1xf32> + +// CHECK-LABEL: LeakyRelu +// CHECK: %0 = "tfl.leaky_relu"(%arg0) {alpha: 1.000000e-01} : (tensor<1xf32>) -> tensor<1xf32> +} +``` + +The RUN command at the top results in running the `mlir-opt` binary (which is a +compiler writer tool to exercise different registered passes) to invoke the +optimization pass this transform was added as part of on the current file and to +verify its output using `FileCheck`. `FileCheck` is a textual output verifier. In +particular it uses the CHECK expressions to verify the given output is produced. + +There can be multiple RUN commands with different corresponding CHECK prefixes. +And in addition multiple independent tests separated by `// -----` and +`mlir-opt` invoked with the `-split-input-file` flag. This is especially useful +for error testing. + +This results in very simple, directed testing without the need to work around +constant propagation or other, unrelated, optimization passes. 
+ +## Adding optimization pass + +Optimization passes that do not fit/are difficult to specify in the above +structure can be specified as general iterations across modules/functions. See +[Writing a Pass](WritingAPass.md) for a general overview and introduction to +optimization passes in MLIR. diff --git a/mlir/docs/Rationale.md b/mlir/docs/Rationale.md new file mode 100644 index 0000000000000000000000000000000000000000..763442dce0638342e314a5b31d4c32f0ab3d173c --- /dev/null +++ b/mlir/docs/Rationale.md @@ -0,0 +1,1121 @@ +# MLIR Rationale + +This document is intended to capture some of the alternatives considered and +open debates in the design of MLIR, along with the rationale for certain +decisions we made. This is not intended to be a "finely groomed" document - we +prefer the ability to dump in interesting tidbits without worrying too much +about their consistency or readability. + +[TOC] + +## Abstract + +MLIR is a compiler intermediate representation with similarities to traditional +three-address SSA representations (like +[LLVM IR](http://llvm.org/docs/LangRef.html) or +[SIL](https://github.com/apple/swift/blob/master/docs/SIL.rst)), but which +introduces notions from the polyhedral loop optimization works as first class +concepts. This hybrid design is optimized to represent, analyze, and transform +high level dataflow graphs as well as target-specific code generated for high +performance data parallel systems. Beyond its representational capabilities, its +single continuous design provides a framework to lower from dataflow graphs to +high performance target specific code. + +MLIR stands for one of "Multi-Level IR" or "Multi-dimensional Loop IR" or +"Machine Learning IR" or "Mid Level IR", we prefer the first. This document only +provides the rationale behind MLIR -- its actual +[specification document](LangRef.md) and other content is hosted elsewhere. 
+ +## Introduction and Motivation + +The Multi-Level Intermediate Representation (MLIR) is intended for easy +expression and optimization of computations involving deep loop nests and dense +matrices of high dimensionality. It is thus well-suited to deep learning +computations in particular. Yet it is general enough to also represent arbitrary +sequential computation. The representation allows high-level optimization and +parallelization for a wide range of parallel architectures including those with +deep memory hierarchies --- general-purpose multicores, GPUs, and specialized +neural network accelerators. + +MLIR uses ideas drawn from IRs of LLVM and Swift for lower level constructs +while combining them with ideas from the polyhedral abstraction to represent +loop nests, multidimensional data (tensors), and transformations on these +entities as first class concepts in the IR. + +MLIR is a multi-level IR, i.e., it represents code at a domain-specific +representation such as HLO or TensorFlow graphs, all the way down to the machine +level. MLIR is able to represent arbitrary control flow and arbitrary data +accesses, and is general enough to represent nearly all sequential computation. +This is a key distinction from existing polyhedral representation +implementations (such as LLVM [Polly](https://polly.llvm.org/)) that are able to +use the polyhedral abstraction in a way isolated from the LLVM IR and only for +affine loop nests, i.e., portions of the code where array accesses, loop bounds, +and conditionals are regular (involve linear functions of loop iterators and +constant symbols). The presence of statically unpredictable data accesses or +control flow does not preclude representation in MLIR, but only limits to a +certain extent the ability to reason about and apply transformations using the +polyhedral abstraction. 
+ +Maps, sets, and relations with affine constraints are the core structures +underlying a polyhedral representation of high-dimensional loop nests and +multidimensional arrays. These structures are represented as textual +expressions in a form close to their mathematical form. These structures are +used to capture loop nests, tensor data structures, and how they are reordered +and mapped for a target architecture. All structured or "conforming" loops are +captured as part of the polyhedral information, and so are tensor variables, +their layouts, and subscripted accesses to these tensors in memory. + +The information captured in the IR allows a compact expression of all loop +transformations, data remappings, explicit copying necessary for explicitly +addressed memory in accelerators, mapping to pre-tuned expert written +primitives, and mapping to specialized vector instructions. Loop transformations +that can be easily implemented include the body of affine transformations: these +subsume all traditional loop transformations (unimodular and non-unimodular) +such as loop tiling, interchange, permutation, skewing, scaling, relative +shifting, reversal, fusion, and distribution/fission. Transformations on data +layout such as padding and transforming to blocked layouts are also represented +well via affine layout maps. + +MLIR's design allows a progressive lowering to target-specific forms. Besides +high-level transformations for loop nests and data layouts that a typical +mid-level optimizer is expected to deal with, MLIR is also designed to perform +certain low-level scheduling and mapping decisions that a typical backend IR is +entrusted with: these include mapping to specialized vector instructions, +auto-vectorization, and software pipelining. 
The need to support these +transformations stems from the fact that neural network accelerators have +specialized units that deal with large chunks of data whose computation maps +back to chunks of more than one loop of the loop nests as viewed by a program at +a level closer to the original specification. Such specialized units or +instructions operate on multidimensional data chunks from a programmer's +viewpoint. It thus makes it hard or infeasible for a backend operating on a very +low-level IR close to assembly to lift and reconstruct loops and perform such a +mapping. This is in contrast to classic instruction selection and scheduling in +today's compilers that primarily only deals with the body of the innermost loop. +MLIR also facilitates automatic mapping to expert pre-tuned primitives or vendor +libraries operating on data at higher levels (or at the highest level) of the +memory hierarchy. + +In summary, MLIR is convenient for and closed under the kind of transformations +needed to lower to general-purpose as well as specialized accelerators. It also +allows one to build modular and reusable target independent and target dependent +passes. + +## Design Decisions + +This section sheds light on some of the design decisions -- some of these are +indirectly implied by the specification document. + +### Loads and stores + +The 'load' and 'store' instructions are specifically crafted to fully resolve to +an element of a memref. These instructions take as arguments n+1 indices for an +n-ranked tensor. This disallows the equivalent of pointer arithmetic or the +ability to index into the same memref in other ways (something which C arrays +allow for example). Furthermore, for the affine constructs, the compiler can +follow use-def chains (e.g. 
through +[affine.apply operations](Dialects/Affine.md#affineapply-operation)) or through +the map attributes of [affine operations](Dialects/Affine.md#Operations)) to +precisely analyze references at compile-time using polyhedral techniques. This +is possible because of the [restrictions on dimensions and symbols](Dialects/Affine.md#restrictions-on-dimensions-and-symbols). + +A scalar of element-type (a primitive type or a vector type) that is stored in +memory is modeled as a 0-d memref. This is also necessary for scalars that are +live out of for loops and if conditionals in a function, for which we don't yet +have an SSA representation -- +[an extension](#mlfunction-extensions-for-"escaping-scalars") to allow that is +described later in this doc. + +### Symbols and types + +The current MLIR disallows use of symbols in types. For example, when a tensor +or memref dimension is statically unknown, it is denoted in the type as '?'. An +SSA symbol is then bound to it when a memref is created. The actual value of the +unknown dimension can be queried using the "dim" builtin as shown below. + +Example: + +```mlir +func foo(...) { + %A = alloc <8x?xf32, #lmap> (%N) + ... + call bar(%A) : (memref<8x?xf32, #lmap>) +} + +func bar(%A : memref<8x?xf32, #lmap>) { + // Type of %A indicates that %A has dynamic shape with 8 rows + // and unknown number of columns. The number of columns is queried + // dynamically using dim instruction. + %N = dim %A, 1 : memref<8x?xf32, #lmap> + + affine.for %i = 0 to 8 { + affine.for %j = 0 to %N { + // A[i,j] += 1 + %s1 = affine.load %A[%i, %j] : memref<8x?xf32, #lmap> + %s2 = add %s1, 1 + affine.store %s2, %A[%i, %j] : memref<8x?xf32, #lmap> + } + } + return +} + +``` + +An alternative design is to embed the reference to symbols directly in the +type - memref<8x%Nxf32>. We went for the current approach in MLIR because it +simplifies the design --- types remain immutable when the values of symbols +change. 
+ +### Block Arguments vs PHI nodes + +MLIR Regions represent SSA using "[block arguments](LangRef.md#blocks)" rather +than [PHI instructions](http://llvm.org/docs/LangRef.html#i-phi) used in LLVM. +This choice is representationally identical (the same constructs can be +represented in either form) but block arguments have several advantages: + +1. LLVM PHI nodes always have to be kept at the top of a block, and + transformations frequently have to manually skip over them. This is defined + away with BB arguments. +1. LLVM has a separate function Argument node. This is defined away with BB + arguments, because the arguments to the entry block serve this purpose. +1. Blocks of PHI nodes in LLVM execute atomically, which is surprising and + super confusing to compiler engineers and it is easy to introduce bugs with + this (very related to the + "[lost copy](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.524.5461&rep=rep1&type=pdf)" + problem in SSA lowering literature.) With the BB argument representation, + this confusion is defined away. +1. The entry list of PHI nodes in LLVM are unordered, and some blocks have + thousands of predecessors (e.g. unwind blocks). This can cause long compile + time problems because transformations have to linearly scan this list. This + is defined away with BB argument representation. +1. LLVM has no way to represent values that are available only in one successor + but not the other, e.g. its invoke instruction cannot produce the exception + value JUST on the exception edge. Instead, the + [landingpad instruction](http://llvm.org/docs/LangRef.html#landingpad-instruction) + is a hack used to represent this. MLIR doesn't make use of this capability, + but SIL uses it extensively, e.g. in the + [switch_enum instruction](https://github.com/apple/swift/blob/master/docs/SIL.rst#switch-enum). 
+
+For more context, block arguments were previously used in the Swift
+[SIL Intermediate Representation](https://github.com/apple/swift/blob/master/docs/SIL.rst),
+and described in
+[a talk on YouTube](https://www.youtube.com/watch?v=Ntj8ab-5cvE). The section of
+interest
+[starts here](https://www.google.com/url?q=https://youtu.be/Ntj8ab-5cvE?t%3D596&sa=D&ust=1529450150971000&usg=AFQjCNFQHEWL7m8q3eO-1DiKw9zqC2v24Q).
+
+### Index type disallowed in vector/tensor/memref types
+
+Index types are not allowed as elements of `vector`, `tensor` or `memref` type.
+Index types are intended to be used for platform-specific "size" values and may
+appear in subscripts, sizes of aggregate types and affine expressions. They are
+also tightly coupled with `affine.apply` and affine.load/store operations;
+having `index` type is a necessary precondition for a value to be acceptable by
+these operations. While it may be useful to have `memref` to express
+indirect accesses, e.g. sparse matrix manipulations or lookup tables, it creates
+problems MLIR is not ready to address yet. MLIR needs to internally store
+constants of aggregate types and emit code operating on values of those types,
+which are subject to target-specific size and alignment constraints. Since MLIR
+does not have a target description mechanism at the moment, it cannot reliably
+emit such code. Moreover, some platforms may not support vectors of type
+equivalent to `index`.
+
+Indirect access use cases can be alternatively supported by providing an
+`index_cast` instruction that allows for conversion between `index` and
+fixed-width integer types, at the SSA value level. It has an additional benefit
+of supporting smaller integer types, e.g. `i8` or `i16`, for small indices
+instead of (presumably larger) `index` type.
+
+### Bit width of non-primitive types and `index` is undefined
+
+The bit width of a compound type is not defined by MLIR, it may be defined by a
+specific lowering pass. 
In MLIR, bit width is a property of certain primitive
+_types_, in particular integers and floats. It is equal to the number that
+appears in the type definition, e.g. the bit width of `i32` is `32`, and so is
+the bit width of `f32`. The bit width is not _necessarily_ related to the amount
+of memory (in bytes) or the size of register (in bits) that is necessary to
+store the value of the given type. These quantities are target and ABI-specific
+and should be defined during the lowering process rather than imposed from
+above. For example, `vector<3xi57>` is likely to be lowered to a vector of four
+64-bit integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes,
+rather than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the
+bitwidth. Individual components of MLIR that allocate space for storing values
+may use the bit size as the baseline and query the target description when it is
+introduced.
+
+The bit width is not defined for dialect-specific types at the MLIR level.
+Dialects are free to define their own quantities for type sizes.
+
+### Signless types
+
+Integers in the builtin MLIR type system have a bitwidth (note that the `index`
+type has a symbolic width equal to the machine word size), but they do not have
+an intrinsic sign. This means that the "standard ops" operation set has things
+like `addi` and `muli` which do two's complement arithmetic, but some other
+operations get a sign, e.g. `divis` vs `diviu`.
+
+LLVM uses the [same design](http://llvm.org/docs/LangRef.html#integer-type),
+which was introduced in a revamp rolled out
+[in the LLVM 2.0 integer type](http://releases.llvm.org/2.0/docs/LangRef.html#t_derived).
+Prior to that, from
+[LLVM 1.0](http://releases.llvm.org/1.0/docs/LangRef.html#t_classifications) to
+[1.9](http://releases.llvm.org/1.9/docs/LangRef.html#t_classifications), LLVM
+used signed types like "sbyte" and "ubyte". This shift was important and has
+served LLVM well over the years. 
The reason this is important is that it is a +good thing for an intermediate representation to represent the same computation +with the same instruction. Signed types got in the way, because (e.g.) an "add +of an sbyte" does the same computation as an "add of a ubyte", but the type +system made them look artificially different. This split also required casts +like "cast from sbyte to ubyte" which do nothing at the machine level. Removing +signs from the type system eliminated these problems, making the compiler +simpler. + +More information about this split is available in an old +[talk on youtube](https://www.youtube.com/watch?v=VeRaLPupGks) talking about +LLVM 2.0. + +Note that this rationale only applies to the "standard ops" dialect in which we +can express an opinion about its design. Other dialects generally try to model +an external system, and should aim to reflect its design as closely as possible. + +### Splitting floating point vs integer operations + +The MLIR "standard" operation set splits many integer and floating point +operations into different categories, for example `addf` vs `addi` and `cmpf` vs +`cmpi` +([following the design of LLVM](http://llvm.org/docs/LangRef.html#binary-operations)). +These instructions _are_ polymorphic on the number of elements in the type +though, for example `addf` is used with scalar floats, vectors of floats, and +tensors of floats (LLVM does the same thing with its scalar/vector types). + +This split is important because floating point and integer operations are quite +different in practice: for example, floating point values include NaN's, so +[integer comparisons](http://llvm.org/docs/LangRef.html#icmp-instruction) and +[floating point comparisons](http://llvm.org/docs/LangRef.html#fcmp-instruction) +should use different comparison opcodes. 
On the arithmetic side of things, +floating point operations support rounding modes, floating point contractions, +["fast math"](http://llvm.org/docs/LangRef.html#fadd-instruction), and integers +may want to have two's complement overflow behavior or be undefined on +[various forms of wrapping](http://llvm.org/docs/LangRef.html#add-instruction) +for performance. + +We are a long way from this sort of thing being a priority to care about in +MLIR, but since we have experience and know the right way to do this, we'd +rather design it in from the beginning. + +Note that this rationale only applies to the "standard ops" dialect in which we +can express an opinion about its design. Other dialects generally try to model +an external system, and should aim to reflect its design as closely as possible. + +### Specifying sign in integer comparison operations + +Since integers are [signless](#signless-types), it is necessary to define the +sign for integer comparison operations. This sign indicates how to treat the +foremost bit of the integer: as sign bit or as most significant bit. For +example, comparing two `i4` values `0b1000` and `0b0010` yields different +results for unsigned (`8 > 3`) and signed (`-8 < 3`) interpretations. This +difference is only significant for _order_ comparisons, but not for _equality_ +comparisons. Indeed, for the latter all bits must have the same value +independently of the sign. Since both arguments have exactly the same bit width +and cannot be padded by this operation, it is impossible to compare two values +whose bit representations would differ while the values are interpreted as +equal. + +### Specifying comparison kind as attribute + +Unlike arithmetic, comparison operators share several common properties, e.g. +they cannot be considered associative. In practice, comparisons are sometimes +implemented by the same instruction or its variants so it makes sense to group +them together at the IR level. 
+
+An alternative would be introducing ten distinct operators for all currently
+supported kinds of integer comparisons. These operators would have increased the
+number of "reserved" names used by standard operations as well as the size of
+the C++ API while their implementations would have been mostly identical.
+
+The comparison kind is internally an integer attribute. However, for the sake of
+readability by humans, custom assembly form accepts string literals that are
+mapped to the underlying integer values: `cmpi "eq", %lhs, %rhs` better implies
+integer equality comparison than `cmpi 0, %lhs, %rhs` where it is unclear what
+gets compared to what else. This syntactic sugar is possible thanks to parser
+logic redefinitions for custom assembly form of non-builtin operations.
+Supporting it in the full notation would have required changing how the main
+parsing algorithm works and may have unexpected repercussions. While it had been
+possible to store the predicate as a string attribute, it would have rendered it
+impossible to implement switching logic based on the comparison kind and made
+attribute validity checks (one out of ten possible kinds) more complex.
+
+### 'select' operation to implement min/max
+
+Although `min` and `max` operations are likely to occur as a result of
+transforming affine loops in ML functions, we did not make them first-class
+operations. Instead, we provide the `select` operation that can be combined with
+`cmpi` to implement the minimum and maximum computation. Although they now
+require two operations, they are likely to be emitted automatically during the
+transformation inside MLIR. 
On the other hand, there are multiple benefits of +introducing `select`: standalone min/max would concern themselves with the +signedness of the comparison, already taken into account by `cmpi`; `select` can +support floats transparently if used after a float-comparison operation; the +lower-level targets provide `select`-like instructions making the translation +trivial. + +This operation could have been implemented with additional control flow: `%r = +select %cond, %t, %f` is equivalent to + +```mlir +^bb0: + cond_br %cond, ^bb1(%t), ^bb1(%f) +^bb1(%r): +``` + +However, this control flow granularity is not available in the ML functions +where min/max, and thus `select`, are likely to appear. In addition, simpler +control flow may be beneficial for optimization in general. + +### Regions + +#### Attributes of type 'Block' + +We considered representing regions through `ArrayAttr`s containing a list of a +special type `IRBlockAttr`, which in turn would contain a list of operations. +All attributes in MLIR are unique’d within the context, which would make the IR +inside the regions immortal for no good reason. + +#### Use "inlined" functions as regions + +We considered attaching a "force-inline" attribute on a function and/or a +function `call` operation. Even the minimal region support (use cases in +affine.for and affine.if existing before the regions) requires access to the +values defined in the dominating block, which is not supported by functions. +Conceptually, function bodies are instances of regions rather than the inverse; +regions can also be device kernels, alternative sections, etc. + +#### Dedicated `region` operation + +This would mean we have a special kind of operation that is allowed to have +regions while other operations are not. Such distinction is similar to the +Stmt/Op difference we have had and chose to remove to make the IR simpler and +more flexible. 
It would also require analyses and passes to consider the +interplay between operations (e.g., an `affine.for` operation must be followed +by a region operation). Finally, a region operation can be introduced using the +current implementation, among other operations and without being special in any +sense. + +#### Explicit capture of the values used in a region + +Being able to use values defined outside the region implies that use-def chains +may contain uses from different nested regions. Consequently, IR transformations +and analyses can pull the instruction defining the value across region +boundaries, for example in case of TableGen-defined canonicalization patterns. +This would not be the case if all used values had been passed as region +arguments. One of the motivations for introducing regions in the IR is precisely +to enable cross-region analyses and transformations that are simpler than +inter-procedural transformations. Having uses from different regions appear in +the same use-def chain, contrary to an additional data structure maintaining +correspondence between function call arguments as uses of the original +definitions and formal arguments as new definitions, enables such +simplification. Since individual operations now belong to blocks, which belong +to regions, it is always possible to check if the definition of the value +belongs to the same region as its particular use. The risk is that any IR +traversal will need to handle explicitly this situation and it is easy to forget +a check (or conversely it isn’t easy to design the right check in a tablegen +pattern for example): traversing use-def chains potentially crosses implicitly +semantic barriers, making it possible to unknowingly break region semantics. +This is expected to be caught in the verifier after the transformation. + +At the same time, one may choose to pass certain or all values as region +arguments to explicitly break the use-def chains in the current proposal. 
This +can be combined with an attribute-imposed semantic requirement disallowing the +body of the region to refer to any value from outside it. + +### Quantized integer operations + +We haven't designed integer quantized operations in MLIR, but experience from +TensorFlow suggests that it is better to put information about the quantization +range/scale into the type itself, rather than have a single type like "qint8" +and put these on attributes of the operation. + +There are a few ways to do this with MLIR, including at least: + +* We could do the same thing TensorFlow does - and we will _have_ to support + that model to some extent for compatibility. +* We can encode the fp range of quantized integers directly into the types + when they are constants. The best practice on this seems to be to encode the + zero point as well as a scale factor. This ensures that 0.0 is always + exactly representable, e.g. `qi8<-1.42, 31.23x>`. +* We could theoretically encode dynamically determined ranges into the types + using something like `qi8` with the bounds being determined through the + SSA dataflow graph dynamically - similar to how dynamic shapes are handled. + +We will definitely need to do #1 for compatibility, we probably want to do #2, +and we should investigate #3 over time. That said, our short term plan is to get +more implementation experience with the rest of the system first, then come back +to re-examine the representation for quantized arithmetic when we have that +experience. When we do, we should chat with benoitjacob@ and +[read the paper](https://arxiv.org/abs/1712.05877). + +### Dialect type extensions + +This section describes the design decisions that shaped the dialect extensible +type system present in MLIR. + +#### Reserving dialect type kinds + +Dialects that wish to define type extensions must reserve a range of type kinds +within a '.def' file within the core IR library. 
This means that every dialect
+wishing to define custom types must modify this file, but it guarantees that all
+type casting checks are performed in O(1) time.
+
+#### Interactions between dialects
+
+There are two different interactions between dialects that are important to
+understand. When types of a dialect are:
+
+* In operations of other dialects
+
+    - For standard/builtin operations, only standard/builtin types are
+        allowed. This restriction allows for operations to clearly understand
+        the invariants that they are working under.
+    - Outside of standard/builtin operations, dialects are expected to verify
+        the allowable operation types per operation.
+
+* In types of other dialects
+
+    - For standard/builtin types, these types are allowed to contain types
+        from other dialects. This simplifies the type system and removes the
+        need for dialects to redefine all of the standard aggregate types, e.g.
+        tensor, as well as the memref type. Dialects are expected to verify that
+        a specific type is valid within a standard type, e.g. if a type can be
+        an element of a tensor.
+    - For dialect types, the dialect is expected to verify any type
+        invariants, e.g. if the standard tensor type can contain a specific type
+        of that dialect.
+
+#### Separating builtin and standard types
+
+Following the separation between the built-in and standard dialect, it makes
+sense to separate built-in types and standard dialect types. Built-in types are
+required for the validity of the IR itself, e.g. the function type (which
+appears in function signatures and generic assembly forms of operations).
+Integer, float, vector, memref and tensor types, while important, are not
+necessary for IR validity.
+
+#### Unregistered types
+
+MLIR supports unregistered operations in generic assembly form. MLIR also
+supports a similar concept for types. When parsing, if the dialect for the
+dialect type has not been registered the type is modeled as an 'OpaqueType'. 
This allows
+for types to be round-tripped without needing to link in the dialect library
+that defined them. No additional information about opaque types, outside of
+parsing/printing, will be available.
+
+#### Dialect type syntax
+
+Dialect extended types are represented as string literals wrapped inside of the
+dialect namespace. This means that the parser delegates to the dialect for
+parsing specific type instances. This differs from the representation of dialect
+defined operations, which have an identifier name that the parser uses to
+identify and parse them.
+
+This representation was chosen for several reasons:
+
+##### Dialects must provide custom type parsers
+
+Dialect type parsing cannot plug into the existing parser infrastructure as
+operations do with the OpAsmParser/Printer. Operations have a defined syntax
+structure that is the same across all dialects. Types, on the other hand, may
+have many different, and sometimes conflicting, parsing constraints that would
+be difficult/unmaintainable to provide within a single interface.
+
+This also has the added benefit of encouraging dialects to reuse existing
+external type parsers. For example, an LLVM dialect may provide an MLIR LLVM
+type that is simply a wrapper around LLVM types. The LLVM dialect would then use
+the existing LLVM type parsing infrastructure.
+
+Example:
+
+```mlir
+%s = "foo"() : () -> !llvm<"i32*">
+```
+
+##### Types do not always have canonical names
+
+Unlike operations, types generally do not have a formal canonical name. For
+example, function types have no defined keyword and integer types are defined by
+a regular expression to support arbitrary bitwidth. Dialects with existing type
+systems, e.g. LLVM, are likely to provide wrappers around their existing type
+systems. For these wrapper types there is no simple canonical name, it's logical
+to think of these types as existing within the namespace of the dialect. 
If a +dialect wishes to assign a canonical name to a type, it can be done via +[type aliases](LangRef.md#type-aliases). + +### Tuple types + +The MLIR type system provides first class support for defining +[tuple types](LangRef.md#tuple-type). This is due to the fact that `Tuple` +represents a universal concept that is likely to, and has already begun to, +present itself in many different dialects. Though this type is first class in +the type system, it merely serves to provide a common mechanism in which to +represent this concept in MLIR. As such, MLIR provides no standard operations +for interfacing with `tuple` types. It is up to dialect authors to provide +operations, e.g. extract_tuple_element, to interpret and manipulate them. When +possible, operations should prefer to use multiple results instead. These +provide a myriad of benefits, such as alleviating any need for tuple-extract +operations that merely get in the way of analysis and transformation. + +### Assembly forms + +MLIR decides to support both generic and custom assembly forms under the +following considerations: + +MLIR is an open system; it is designed to support modular and pluggable +dialects. Depending on whether there exists a corresponding dialect and whether +the dialect is plugged in, operations may or may not be registered into MLIR +system. Yet we still need a way to investigate these operations. So the generic +assembly form is mandated by this aspect of MLIR system. It provides a default +textual form for operations. + +On the other hand, an assembly form is for assisting developers to investigate +the IR. The generic form serves as a safe fallback but it can be too verbose for +certain ops. Therefore, MLIR gives each dialect the choice to define a custom +assembly form for each operation according to the operation's semantics and +specific needs. 
The custom assembly form can de-duplicate information from the
+operation to derive a more concise form, thus better facilitating the
+comprehension of the IR.
+
+## Examples
+
+This section describes a few very simple examples that help understand how MLIR
+represents computation.
+
+### Non-affine control flow
+
+```c
+// A simple linear search in every row of a matrix
+for (i = 0; i < N; i++) {
+  for (j = 0; j < N; j++) {
+    // dynamic control flow
+    if (a[i][j] == key) {
+      s[i] = j;
+      break;
+    }
+  }
+}
+```
+
+The presence of dynamic control flow leads to an inner non-affine function
+nested in an outer function that uses affine loops.
+
+```mlir
+func @search(%A: memref, %key : i32) {
+  %ni = dim %A, 0 : memref
+  // This loop can be parallelized
+  affine.for %i = 0 to %ni {
+    call @search_body (%A, %S, %key, %i) : (memref, memref, i32, i32)
+  }
+  return
+}
+
+func @search_body(%A: memref, %S: memref, %key: i32, %i : i32) {
+  %nj = dim %A, 1 : memref
+  br ^bb1(0)
+
+^bb1(%j: i32)
+  %p1 = cmpi "lt", %j, %nj : i32
+  cond_br %p1, ^bb2, ^bb5
+
+^bb2:
+  %v = affine.load %A[%i, %j] : memref
+  %p2 = cmpi "eq", %v, %key : i32
+  cond_br %p2, ^bb3(%j), ^bb4
+
+^bb3(%j: i32)
+  affine.store %j, %S[%i] : memref
+  br ^bb5
+
+^bb4:
+  %jinc = addi %j, 1 : i32
+  br ^bb1(%jinc)
+
+^bb5:
+  return
+}
+```
+
+As per the [MLIR spec](LangRef.md), the restrictions on dimensions and symbol
+identifiers to be used with the affine.apply operation only apply to accesses
+inside `affine.for` and `affine.if` operations. However, an analysis of accesses
+inside the called function (`@search_body`) is necessary to determine if the
+`%i` loop could be parallelized: such function access analysis is calling
+context sensitive.
+
+### Non-affine loop bounds
+
+Loop bounds that are not affine lead to a nesting of functions as shown below.
+
+```c
+for (i = 0; i < N; i++)
+  for (j = 0; j < N; j++)
+    // Non-affine loop bound for k loop.
+    for (k = 0; k < pow(2, j); k++) +       for (l = 0; l < N; l++) { +        // block loop body +        ... +       } +``` + +```mlir +func @outer_nest(%n : index) { + affine.for %i = 0 to %n { + affine.for %j = 0 to %n { + %pow = call @pow(2, %j) : (index, index) -> index + call @inner_nest(%pow, %n) : ... + } + } + return +} + +func @inner_nest(%m : index, %n : index) { + affine.for %k = 0 to %m { + affine.for %l = 0 to %n { + ... + } + } + return +} +``` + +### Reference 2D Convolution + +The following example illustrates a reference implementation of a 2D +convolution, which uses an integer set `#domain` to represent valid input data +in a dilated convolution. + +```mlir +// Dilation factors S0 and S1 can be constant folded if constant at compile time. +#domain = (d0, d1)[S0,S1,S2,S3]: (d0 % S0 == 0, d1 % S1 == 0, d0 >= 0, d1 >= 0, + S3 - d0 - 1 >= 0, S4 - d1 - 1 >= 0) +// Identity map (shown here for illustration). +#map0 = (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6) + +// Affine map from output to input coordinate space. +// d0 = output_h, d1 = output_w, d2 = kernel_h, d3 = kernel_w +// S0 = h_stride, S1 = w_stride, S2 = h_kernel_dilation, S3 = w_kernel_dilation +// S4 = h_pad_low, S5 = w_pad_low +// %out0 = %0#1 * %h_stride + %0#4 * %h_kernel_dilation - %h_pad_low +// %out1= %0#2 * %w_stride + %0#5 * %w_kernel_dilation - %w_pad_low +#map1_0 = (d0, d1, d2, d3) [S0, S1, S2, S3, S4, S5] -> (d0 * S0 + d2 * S2 - %S4) +#map1_1 = (d0, d1, d2, d3) [S0, S1, S2, S3, S4, S5] -> (d1 * S1 + d3 * S3 - %S5) + +// Semi-affine map to undilated input coordinate space. +// d0 = input_h, d1 = input_w, S0 = h_base_dilation, S1 = w_base_dilation. 
+#map2_0 = (d0, d1) [S0, S1] -> (d0 / S0)
+#map2_1 = (d0, d1) [S0, S1] -> (d1 / S1)
+
+// Conv2D shapes:
+// input: [batch, input_height, input_width, input_feature]
+// kernel: [kernel_height, kernel_width, input_feature, output_feature]
+// output: [batch, output_height, output_width, output_feature]
+func @conv2d(%input: memref<16x1024x1024x3xf32, #lm0, /*scratchpad=*/1>,
+             %kernel: memref<5x5x3x32xf32, #lm0, /*scratchpad=*/1>,
+             %output: memref<16x512x512x32xf32, #lm0, /*scratchpad=*/1>) {
+  affine.for %b = 0 to %batch {
+    affine.for %oh = 0 to %output_height {
+      affine.for %ow = 0 to %output_width {
+        affine.for %of = 0 to %output_feature {
+          affine.for %kh = 0 to %kernel_height {
+            affine.for %kw = 0 to %kernel_width {
+              affine.for %if = 0 to %input_feature {
+                // Calculate input indices.
+                %1_0 = affine.apply #map1_0 (%0#1, %0#2, %0#4, %0#5)
+                  [%h_stride, %w_stride, %h_kernel_dilation, %w_kernel_dilation,
+                   %h_pad_low, %w_pad_low]
+                %1_1 = affine.apply #map1_1 (%0#1, %0#2, %0#4, %0#5)
+                  [%h_stride, %w_stride, %h_kernel_dilation, %w_kernel_dilation,
+                   %h_pad_low, %w_pad_low]
+
+                // Check if access is not in padding.
+                affine.if #domain(%1_0, %1_1)
+                  [%h_base_dilation, %w_kernel_dilation, %h_bound, %w_bound] {
+                  %2_0 = affine.apply #map2_0 (%1_0, %1_1)
+                  %2_1 = affine.apply #map2_1 (%1_0, %1_1)
+                  // Compute: output[output_indices] += input[input_indices] * kernel[kernel_indices]
+                  call @multiply_accumulate(%input, %kernel, %output, %b, %oh, %ow, %of, %kh, %kw, %if, %2_0, %2_1)
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return
+}
+```
+
+TODO (Add more examples showing the IR for a variety of interesting cases)
+
+## Design alternatives and extensions
+
+This is a list of some design alternatives and extensions that we discussed in
+detail but did not include in the spec or postponed them for future
+consideration on demand. 
We will revisit these discussions when we have more +implementation experience and learn more about the challenges and limitations of +our current design in practice. + +### Polyhedral code representation alternatives: schedule lists vs schedules trees vs affine loop/if forms + +The current MLIR uses a representation of polyhedral schedules using a tree of +if/for loops. We extensively debated the tradeoffs involved in the typical +unordered polyhedral instruction representation (where each instruction has +multidimensional schedule information), discussed the benefits of schedule tree +forms, and eventually decided to go with a syntactic tree of affine if/else +conditionals and affine for loops. Discussion of the tradeoff was captured in +this document: +[ MLIR: The case for a simplified polyhedral form](RationaleSimplifiedPolyhedralForm.md). + +At a high level, we have two alternatives here: + +1. Schedule tree representation instead of an affine loop AST form: The current + proposal uses an affine loop and conditional tree form, which is syntactic + and with no separation of domains as sets and schedules as multidimensional + affine functions. A schedule tree form however makes polyhedral domains and + schedules a first class concept in the IR allowing compact expression of + transformations through the schedule tree without changing the domains of + instructions. Such a representation also hides prologues, epilogues, partial + tiles, complex loop bounds and conditionals making loop nests free of + "syntax". Cost models instead look at domains and schedules. In addition, if + necessary such a domain schedule representation can be normalized to + explicitly propagate the schedule into domains and model all the cleanup + code. An example and more detail on the schedule tree form is in the next + section. +1. Having two different forms of "affine regions": an affine loop tree form + and a polyhedral schedule tree form. 
In the latter, ops could carry + attributes capturing domain, scheduling, and other polyhedral code + generation options with IntegerSet, AffineMap, and other attributes. + +#### Schedule Tree Representation for Affine Regions + +This representation is based on a simplified form of the domain/schedule +representation used by the polyhedral compiler community. Domains represent what +has to be executed while schedules represent the order in which domain elements +are interleaved. We model domains as non-piece-wise convex integer sets, and +schedules as affine functions; however, the former can be disjunctive, and the +latter can be piece-wise affine relations. In the schedule tree representation, +domain and schedules for instructions are represented in a tree-like structure +which is called a schedule tree. Each non-leaf node of the tree is an abstract +polyhedral dimension corresponding to an abstract fused loop for each ML +instruction that appears in that branch. Each leaf node is an ML Instruction. + +```mlir +// A tiled matmul code (128x128x128) represented in schedule tree form + +// #map0 = (d0, d1, d2, d3, d4, d5) -> (128*d0 + d3, 128*d1 + d4, 128*d2 + d5) +#intset_ij = (i, j) [M, N, K] : i >= 0, -i + N - 1 >= 0, j >= 0, -j + N-1 >= 0 +#intset_ijk = (i, j, k) [M, N, K] : i >= 0, -i + N - 1 >= 0, j >= 0, + -j + M-1 >= 0, k >= 0, -k + N - 1 >= 0) +func @matmul(%A, %B, %C, %M, %N, %K) : (...) { // %M, N, K are symbols + // t1, t2, t3, t4, t5, t6 are abstract polyhedral loops + mldim %t1 : {S1,S2,S3,S4,S5} floordiv (i, 128) { + mldim %t2 : {S1,S2,S3,S4,S5} floordiv (j, 128) { + // (%i, %j) = affine.apply (d0, d1) -> (128*d0, 128*d1) (%t1, %t2) + call dma_mem_to_scratchpad(%C, %i, %j, %M, %N, %K) + with @intset_ij(%i, %j) [%M, %N, %K] + mldim %t3 : {S2,S3,S4,S5} floordiv (k, 128) { + // (%i, %j, %k) = affine.apply (d0, d1, d2) + // -> (128*d0, 128*d1, 128*d2) (%t1, %t2, %t3) + call dma_mem_to_scratchpad(%A, ...) 
with #inset_ijk (%i, %j, %k) [%M, %N, %K] + // (%i, %j, %k) = affine.apply (d0, d1, d2) + // -> (128*d0, 128*d1, 128*d2) (%t1, %t2, %t3) + call dma_mem_to_scratchpad(%B, ...) with #inset_ijk (%i, %j, %k) [%M, %N, %K] + mldim %t4 : {S4} i mod 128 { + mldim %t5 : {S4} j mod 128 { + mldim %t6 : {S4} k mod 128 { + // (%i, %j, %k) = affine.apply #map0 (%t1, %t2, %t3, %t4, %t5, %t6) + call matmul_body(A, B, C, %i, %j, %k, %M, %N, %K) + with #inset_ijk(%i, %j, %k) [%M, %N, %K] + } // end mld4im t6 + } // end mldim t5 + } // end mldim t4 + } // end mldim t3 + // (%i, %j) = affine.apply (d0, d1) -> (128*d0, 128*d1) (%t1, %t2) + call $dma_scratchpad_to_mem_C ... with #intset(%i, %j) [%M, %N, %K] + } // end mldim t2 + } // end mldim t1 + return +} + +``` + +### Affine Relations + +The current MLIR spec includes affine maps and integer sets, but not affine +relations. Affine relations are a natural way to model read and write access +information, which can be very useful to capture the behavior of opaque external +library calls, high-performance vendor libraries, or user-provided / user-tuned +routines. + +An affine relation is a relation between input and output dimension identifiers +while being symbolic on a list of symbolic identifiers and with affine +constraints on the identifiers. + +Syntax: + +``` +// Affine relation definition at the top of file +affine-rel-def ::= affine-rel-id `=` affine-relation-inline + +affine-rel-id ::= `##` prefixed-id + +affine-relation-inline ::= + `(` input-dims `)` (`[` symbols `]`)? `->` + `(` output-dims `)` : affine-constraint-conjunction + +input-dims ::= bare-id-list +output-dims ::= bare-id-list +symbols ::= bare-id-list + +affine-rel ::= affine-rel-id | affine-relation-inline + +// Usage +affine-rel-spec ::= affine-rel dim-and-symbol-use-list +``` + +All identifiers appearing in input-dims, output-dims, and symbol-dims are +pairwise distinct. 
All affine-constraint non-terminals in the above syntax are +allowed to contain identifiers only from input-dims, output-dims, and +symbol-dims. + +Affine relations are used to model read, write, may_read, and may_write sets of +functions in the IR. The output dimension identifiers correspond to the data +dimensions. + +Example: + +```mlir +// read relation: two elements ( d0 <= r0 <= d0+1 ) +##aff_rel9 = (d0) -> (r0) : r0 - d0 >= 0, d0 - r0 + 1 >= 0 + +func @count (%A : memref<128xf32>, %pos : i32) -> f32 + reads: {%A ##aff_rel9 (%pos)} + writes: /* empty */ + may_reads: /* empty */ + may_writes: /* empty */ { +bb0 (%0, %1: memref<128xf32>, i64): + %val = affine.load %A [%pos] + %val = affine.load %A [%pos + 1] + %p = mulf %val, %val : f32 + return %p : f32 +} +``` + +### Regions + +#### Making function definition an operation + +MLIR supports values of a Function type. Instead of having first-class IR +concept for functions, one could define an operation with a body region that +defines a function value. The particularity of functions is that their names are +globally visible and can be referred to before being defined, unlike SSA values +that must be defined first. Implementing a "function definition" operation would +require to relax some of the SSA constraints in a region, and also make the IR +Module a region as well. It would also affect the core infrastructure (e.g., +function passes) only for the sake of concept unification. + +#### Having types on a region + +Instead of inspecting the types of arguments of the first block, one could give +the region itself a type. This type would be redundant with block argument +types, which must have values and create room for type mismatches. While +functions do have types that are partly redundant with the arguments of the +first block in the function, this is necessary to support function declarations +that do not have a body which we can refer to in order to obtain the argument +types. 
A region is always contained in an operation or a function that can be +queried to obtain the “type” of the region if necessary. + +A type on a region can be justified if Regions were to be considered separately +from the enclosing entity (operation or function) and had their own semantics +that should be checked. + +#### Attaching attributes to regions + +Regions could be annotated with dialect attributes to use attribute verification +hooks. An operation could take multiple regions as arguments, and each of them +may require different attributes. However, there are currently very few +practical cases where this would be necessary. Instead, one could simulate +per-region attributes with array attributes attached to the entity containing +the region (operation or function). This decreases the overall complexity of the +IR and enables more concise and op-specific forms, e.g., when all regions of an +op have the same attribute that can be only mentioned once. Since the semantics +of the region is entirely defined by the enclosing entity, it also makes sense +to have attributes attached to that entity rather than to the region itself. + +This can be reconsidered in the future if we see a non-neglectable amount of use +cases. + +### Read/Write/May_Read/May_Write sets for External Functions + +Having read, write, may_read, and may_write sets for external functions which +include opaque ones, high-performance vendor libraries such as CuDNN, CuB, MKL, +FFT libraries, user-provided/optimized functions, or data movement runtimes such +as DMA ones is a powerful feature. It allows the compiler to perform analysis, +composition/transformation in the presence of such calls and with loops around +such calls on sub-tensors. For user-provided or custom hand-tuned functions, the +read/write/may_read/may_write sets could be provided a-priori by a user as part +of the external function signature or they could be part of a database. 
+ +TODO: Design this, and update to use function attribute syntax. + +Example: + +```mlir +##rel9 ( ) [s0] -> (r0, r1) : 0 <= r0 <= 1023, 0 <= r1 <= s0 - 1 + +func @cblas_reduce_ffi(%M: memref<1024 x ? x f32, #layout_map0, /*mem=*/0>) + -> f32 [ + reads: {%M, ##rel9() } + writes: /* empty */ + may_reads: /* empty */ + may_writes: /* empty */ +] + +func @dma_mem_to_scratchpad(%a : memref<1024 x f32, #layout_map0, /*mem=*/0>, + %b : memref<1024 x f32, #layout_map0, 1>, %c : memref<1024 x f32, + #layout_map0>) [ + reads: {%M, ##rel9() } + writes: /* empty */ + may_reads: /* empty */ + may_writes: /* empty */ + ] + +``` + +### Memref Extensions + +1. Arbitrary polyhedral shapes for tensors: e.g., triangular shapes in tensor + dimensions where there is symmetry: use integer set (affine constraints) to + model tensor data space (instead of just extents). Requires some changes to + the IR and the in-memory form. +1. Layout maps + + 1. Allow piece-wise affine maps for layouts: allows clean modeling of + boundary cases for images/tensors through padding, wrapping, mirroring, + padding where padded values are the results of computation as opposed to + data, padding in the interior as opposed to just boundaries. + 1. Allow many-to-one layout maps: Index and layout maps in the current + proposal are bijective. Extending them to many-to-one layout maps allows + cleaner(?) modeling of broadcast/reduce style computations while reusing + memory. + + Proposal 2(a) requires non-trivial changes to the IR and the in-memory + representation. 2(b) requires no change, but impacts how cost models look at + index and layout maps. + +### `affine.if` and `affine.for` Extensions for "Escaping Scalars" + +We considered providing a representation for SSA values that are live out of +`if/else` conditional bodies and loop carried in `affine.for` loops. We +ultimately abandoned this approach due to its complexity. 
In the current design +of MLIR, scalar variables cannot escape for loops or if instructions. In +situations, where escaping is necessary, we use zero-dimensional tensors and +memrefs instead of scalars. + +**TODO**: This whole section is obsolete and should be updated to use block +arguments and a yield like terminator in for/if instructions. + +The abandoned design of supporting escaping scalars is as follows: + +#### affine.for Instruction + +Syntax: + +``` +[ =] +for % = ... step + [with ] { } +``` + +out-var-list is a comma separated list of SSA values defined in the loop body +and used outside the loop body. in-var-list is a comma separated list of SSA +values used inside the loop body and their initializers. loop-instruction-list +is a list of instructions that may also include a yield instruction. + +Example: + +```mlir +// Return sum of elements in 1-dimensional mref A +func i32 @sum(%A : memref, %N : i32) -> (i32) { + %init = 0 + %result = affine.for %i = 0 to N with %tmp(%init) { + %value = affine.load %A[%i] + %sum = %value + %tmp + yield %sum + } + return %result : i32 +} +``` + +#### affine.if/else Instruction + +Syntax: + +``` + = affine.if () {...} [else {...}] +``` + +Out-var-list is a list of SSA values defined by the if-instruction. The values +are arguments to the yield-instruction that occurs in both then and else clauses +when else clause is present. When if instruction contains only if clause, the +escaping value defined in the then clause should be merged with the value the +variable had before the if instruction. The design captured here does not handle +this situation. + +Example: + +```mlir +// Compute sum of half of the array +func i32 @sum_half(%A : memref, %N : i32) -> (i32) { + %s0 = 0 + %s1 = affine.for %i = 1 ... 
N step 1 with %s2 (%s0) { + %s3 = if (%i >= %N / 2) { + %v0 = affine.load %A[%i] + %s4 = %s2 + %v0 + yield %s4 + } + yield %s3 + } + return %s1 : i32 +} +``` + +### Multithreading the compiler + +People want compilers to go fast, and one simple way to do that is to +multi-thread them. There are multiple strategies for this, but a simple one is +to optimize and compile separate functions in parallel. LLVM's original pass +manager anticipated this demand, and the CallGraphSCCPass manager is even +designed to support this as well, but unfortunately, a few early design +decisions in LLVM prevent this from ever happening. Instead, things like ThinLTO +are forced to split programs into separate LLVM modules/context and optimize +those chunks independently. + +The problem is that LLVM has several objects in its IR that are globally uniqued +and also mutable: notably constants like `i32 0`. In LLVM, these constants are +`Value`'s, which allow them to be used as operands to instructions, and that +they also have SSA use lists. Because these things are uniqued, every `i32 0` in +any function shares a use list. This means that optimizing multiple functions in +parallel won't work (at least without some sort of synchronization on the use +lists, which would be unbearably inefficient). + +MLIR now supports a multithreaded pass manager. We do this through several +design choices: + +1. MLIR makes use of extensive uniqued immutable data structures (affine + expressions, types, etc are all immutable, uniqued, and immortal). +2. Constants are defined in per-function pools, instead of being globally + uniqued. +3. Functions themselves are not SSA values either, so they don't have the same + problem as constants. +4. FunctionPasses are copied (through their copy ctor) into one instance per + thread, avoiding sharing of local state across threads. + +This allows MLIR function passes to support efficient multithreaded compilation +and code generation. 
diff --git a/mlir/docs/RationaleSimplifiedPolyhedralForm.md b/mlir/docs/RationaleSimplifiedPolyhedralForm.md new file mode 100644 index 0000000000000000000000000000000000000000..ec2ecc9fe502a1cd64d4d8bca241b6d120e35b89 --- /dev/null +++ b/mlir/docs/RationaleSimplifiedPolyhedralForm.md @@ -0,0 +1,415 @@ +# MLIR: The case for a simplified polyhedral form + +MLIR embraces polyhedral compiler techniques for their many advantages +representing and transforming dense numerical kernels, but it uses a form that +differs significantly from other polyhedral frameworks. + +**Disclaimer / Warning** + +This document is a very early design proposal (which has since been accepted) +that explored the tradeoffs of using this simplified form vs the traditional +polyhedral schedule list form. At some point, this document could be dusted off +and written as a proper academic paper, but until now, it is better to included +it in this crafty form than not to. Beware that this document uses archaic +syntax and should not be considered a canonical reference to modern MLIR. + +## Introduction + +This document discusses general goals of the project, introduces context and the +two alternatives, then talks about the tradeoffs of these designs. Written by +Chris Lattner. + +## General goals of an IR, and goals of mlfunc's specifically + +Our currently planned representation for MLIR consists of two kinds of +functions: an LLVM-like "CFG Function" and an "ML Function": a function +represented in multidimensional loop form. The idea is that a CFG function is +capable of full generality for expressing arbitrary computation, but is awkward +for loop transformations. In contrast, mlfunc's are limited (e.g. to control +flow involving loop nests over affine spaces) but these limitations make it much +easier to transform and analyze, particularly for the set of computations in a +machine learning kernel. 
The design of an intermediate representation is an optimization problem
consider C-like code like: + +```c + void simple_example(...) { + for (int i = 0; i < N; ++i) { + for (int j = 0; j < N; ++j) { + float tmp = X[i,j] // S1 + A[i,j] = tmp + 1 // S2 + B[i,j] = tmp * 42 // S3 + } + } + } +``` + +The polyhedral representation doesn't care about the actual computation, so we +will abstract them into S1/S2/S3 in the discussion below. Originally, we planned +to represent this with a classical form like (syntax details are not important +and probably slightly incorrect below): + +``` + mlfunc @simple_example(... %N) { + %tmp = call @S1(%X, %i, %j) + domain: (0 <= %i < %N), (0 <= %j < %N) + schedule: (i, j, 0) + + call @S2(%tmp, %A, %i, %j) + domain: (0 <= %i < %N), (0 <= %j < %N) + schedule: (i, j, 1) + + call @S3(%tmp, %B, %i, %j) + domain: (0 <= %i < %N), (0 <= %j < %N) + schedule: (i, j, 2) + } +``` + +In this design, an mlfunc is an unordered bag of instructions whose execution +order is fully controlled by their schedule. + +However, we recently agreed that a more explicit schedule tree representation is +a better fit for our needs, because it exposes important structure that will +make analyses and optimizations more efficient, and also makes the scoping of +SSA values more explicit. This leads us to a representation along the lines of: + +``` + mlfunc @simple_example(... %N) { + d0/d1 = mlspace + for S1(d0), S2(d0), S3(d0) { + for S1(d1), S2(d1), S3(d1) { + + %tmp = call @S1(%X, d0, d1) ;; S1 + domain: (0 <= d0 < %N), (0 <= d1 < %N) + + call @S2(%tmp, %A, d0, d1) ;; S2 + domain: (0 <= d0 < %N), (0 <= d1 < %N) + + call @S3(%tmp, %B, d0, d1) ;; S3 + domain: (0 <= d0 < %N), (0 <= d1 < %N) + } + } + } +``` + +This change makes the nesting structure of the loops an explicit part of the +representation, and makes lexical ordering within a loop significant +(eliminating the constant 0/1/2 of schedules). 
+ +It isn't obvious in the example above, but the representation allows for some +interesting features, including the ability for instructions within a loop nest +to have non-equal domains, like this - the second instruction ignores the outer +10 points inside the loop: + +``` + mlfunc @reduced_domain_example(... %N) { + d0/d1 = mlspace + for S1(d0), S2(d0) { + for S1(d1), S2(d1) { + %tmp = call @S1(%X, d0, d1) ;; S1 + domain: (0 <= d0 < %N), (0 <= d1 < %N) + + call @S2(%tmp, %A, d0, d1) ;; S2 + domain: (10 <= d0 < %N-10), (10 <= d1 < %N-10) + } + } + } +``` + +It also allows schedule remapping within the instruction, like this example that +introduces a diagonal skew through a simple change to the schedules of the two +instructions: + +``` + mlfunc @skewed_domain_example(... %N) { + d0/d1 = mlspace + for S1(d0), S2(d0+d1) { + for S1(d0+d1), S2(d1) { + %tmp = call @S1(%X, d0, d1) ;; S1 + domain: (0 <= d0 < %N), (0 <= d1 < %N) + + call @S2(%tmp, %A, d0, d1) ;; S2 + domain: (0 <= d0 < %N), (0 <= d1 < %N) + } + } + } +``` + +This form has great power, and the polyhedral code generator (which lowers from +an mlfunc to a cfgfunc representation) handles this power so things that +introduce loop transformations don't have to explicitly manipulate the looping +structure. + +## Proposal: Simplified Polyhedral Form + +This document proposes and explores the idea of going one step further, moving +all of the domain and schedule information into the "schedule tree". In this +form, we would have a representation where all instructions inside of a given +for-loop are known to have the same domain, which is maintained by the loop. In +the simplified form, we also have an "if" instruction that takes an affine +condition. + +Our simple example above would be represented as: + +```mlir + mlfunc @simple_example(... %N) { + affine.for %i = 0 ... %N step 1 { + affine.for %j = 0 ... %N step 1 { + // identity noop in this case, but can exist in general. 
+ %0,%1 = affine.apply #57(%i, %j) + + %tmp = call @S1(%X, %0, %1) + + call @S2(%tmp, %A, %0, %1) + + call @S3(%tmp, %B, %0, %1) + } + } + } +``` + +The example with the reduced domain would be represented with an if instruction: + +```mlir + mlfunc @reduced_domain_example(... %N) { + affine.for %i = 0 ... %N step 1 { + affine.for %j = 0 ... %N step 1 { + // identity noop in this case, but can exist in general. + %0,%1 = affinecall #57(%i, %j) + + %tmp = call @S1(%X, %0, %1) + + if (10 <= %i < %N-10), (10 <= %j < %N-10) { + + %2,%3 = affine.apply(%i, %j) // identity noop in this case + + call @S2(%tmp, %A, %2, %3) + } + } + } + } +``` + +These IRs represent exactly the same information, and use a similar information +density. The 'traditional' form introduces an extra level of abstraction +(schedules and domains) that make it easy to transform instructions at the +expense of making it difficult to reason about how those instructions will come +out after code generation. With the simplified form, transformations have to do +parts of code generation inline with their transformation: instead of simply +changing a schedule to **(i+j, j)** to get skewing, you'd have to generate this +code explicitly (potentially implemented by making polyhedral codegen a library +that transformations call into): + +```mlir +mlfunc @skewed_domain_example(... %N) { + affine.for %t1 = 0 ... 2*N-2 step 1 { + affine.for %t2 = max(0, t1-N+1) ... min(N, t1) step 1 { + (%i, %j) = (%t1-%t2, %t2) + ... + } + } +} +``` + +## Evaluation + +Both of these forms are capable of expressing the same class of computation: +multidimensional loop nests with affine loop bounds and affine memory +references. That said, they pose very different tradeoffs in other ways. + +### Commonality: can express same computation + +Both of these can express the same sorts of computation, e.g. kernels written in +one form are representable in the other form in all cases. 
+ +### Commonality: dependence analysis + +These representations both use affine functions for data layout mapping and +access subscripts, and dependence analysis works the same way. + +### Commonality: difficulty of determining optimal transformation series + +One major challenge in performance of optimization of this sort of code is +choosing the ordering and behavior of various loop transformations that get +applied. There are non-local effects of every decision, and neither +representation helps solve this inherently hard problem. + +### Commonality: compactness of IR + +In the cases that are most relevant to us (hyper rectangular spaces) these forms +are directly equivalent: a traditional instruction with a limited domain (e.g. +the "reduced_domain_example" above) ends up having one level of ML 'if' inside +its loops. The simplified form pays for this by eliminating schedules and +domains from the IR. Both forms allow code duplication to reduce dynamic +branches in the IR: the traditional approach allows instruction splitting, the +simplified form supports instruction duplication. + +It is important to point out that the traditional form wins on compactness in +the extreme cases: e.g. the loop skewing case. These cases will be rare in +practice for our workloads, and are exactly the cases that downstream +transformations want to be explicit about what they are doing. + +### Simplicity of code generation + +A key final stage of an mlfunc is its conversion to a CFG function, which is +required as part of lowering to the target machine. The simplified form has a +clear advantage here: the IR has a direct correspondence to the structure of the +generated code. + +In contrast, the traditional form has significant complexity in the lowering +process to a CFG function, because the verbosity not imbued in the IR needs to +come out during code generation. Code generation from ISL shows that it is +possible to do this, but it is a non-trivial transformation. 
+ +### Ease of transformation + +An advantage for the traditional form is that it is easier to perform certain +transformations on it: skewing and tiling are just transformations on the +schedule of the instructions in question, it doesn't require changing the loop +structure. + +In practice, the simplified form requires moving the complexity of code +generation into the transformations themselves - this is sometimes trivial, +sometimes involved. The author believes that this should be possible by making +the code generation algorithms themselves be library functions that +transformations call into, instead of an opaque block that happens at the end of +the mlfunc processing. + +Also, the sorts of transformations performed today by XLA (including tiling, +padding, unrolling, and other rectangular transformations) should be easy enough +to implement on either representation. The only cases that are a challenge are +more advanced cases like skewing, e.g. for DMA data movement generation. + +### Ease of analysis: Cost models + +The simplified form is much easier for analyses and transformations to build +cost models for (e.g. answering the question of "how much code bloat will be +caused by unrolling a loop at this level?"), because it is easier to predict +what target code will be generated. With the traditional form, these analyses +will have to anticipate what polyhedral codegen will do to a set of instructions +under consideration: something that is non-trivial in the interesting cases in +question (see "Cost of code generation"). + +### Cost of code generation + +State of the art polyhedral code generation is +[expensive and complicated](https://lirias.kuleuven.be/bitstream/123456789/497238/1/toplas-astgen.pdf), +sometimes exponential time complexity. We expect that most machine learning +workloads will be hyper-rectangular, and thus it should be easy to specialize in +important cases. 
That said, the traditional polyhedral representation makes it +very easy to introduce complicated and expensive schedules, and provides no way +to understand and project a cost model for using them. All downstream clients of +the IR need to be prepared to handle the full generality of IR that may come to +them. + +The simplified form defines this away: the concepts in the IR remain simple, and +the code much more directly reflects the cost model for lowering to CFG +functions and machine code. This is expected to be very important in the late +stages of a code generator for an accelerator. + +### SSA in ML Functions + +We agree already that values defined in an mlfunc can include scalar values and +they are defined based on traditional dominance. In the simplified form, this is +very simple: arguments and induction variables defined in for-loops are live +inside their lexical body, and linear series of instructions have the same "top +down" dominance relation that a basic block does. + +In the traditional form though, this is not the case: it seems that a lot of +knowledge about how codegen will emit the code is necessary to determine if SSA +form is correct or not. For example, this is invalid code: + +``` + %tmp = call @S1(%X, %0, %1) + domain: (10 <= %i < %N), (0 <= %j < %N) + schedule: (i, j) + + call @S2(%tmp, %A, %0, %1) + domain: (0 <= %i < %N), (0 <= %j < %N) + schedule: (i, j) +``` + +Because `%tmp` isn't defined on some iterations of the %i loop. + +This matters because it makes the verifier more complicated, but more +significantly, it means that load promotion and other optimizations that will +produce SSA form will need to be aware of this and be able to model what codegen +does. + +An emergent property of this that we discussed recently is that PHI nodes in +mlfunc's (if we support them) will also have to have domains. 
+ +### Lack of redundancy in IR + +The traditional form has multiple encodings for the same sorts of behavior: you +end up having bits on `affine.for` loops to specify whether codegen should use +"atomic/separate" policies, unroll loops, etc. Instructions can be split or can +generate multiple copies of their instruction because of overlapping domains, +etc. + +This is a problem for analyses and cost models, because they each have to reason +about these additional forms in the IR. + +### Suitability to purpose: lowering to machine code + +One of the main drivers for this work is lowering to low-level accelerator code, +including two-dimensional vectorization, insertion of DMAs, and other +utilization of the matrix accelerator units. In the author's opinion, the extra +compactness of the traditional form is a negative for this purpose: reasoning +about the generated machine code will require understanding the mapping from +mlfunc to lowered code, which means that it must understand what code generation +will do. + +In the simplified form, the effect of "code generation" is always obvious from +the IR itself, which should make it easier to perform vectorization to target +instructions and other analyses we need to perform. + +## Third Alternative: two different levels of mlfunc + +One hybrid alternative is to support both the traditional and simplified forms +of mlfunc in our IR. + +The stages could look like this, for example: + +1. Early performance transformations could be done on the traditional form. +1. Partial code generation lowers to the simplified form +1. Target specific lowering phases for tiling, and vectorization and other 2D + transforms that don't benefit much from the traditional form could be run. +1. Final codegen to a cfg func can be done when all of the instructions are + replaced with ones valid on the target. + +While this is possible, it isn't clear what would justify the complexity of this +approach. 
Unless there is a super compelling reason for this, it would be nice +to not do this. **Update:** we discussed this as a design team and agreed that +this wouldn't be a good way to go. diff --git a/mlir/docs/TestingGuide.md b/mlir/docs/TestingGuide.md new file mode 100644 index 0000000000000000000000000000000000000000..723b78bf0f58236c22e7548b5b8de66e2c2dbb47 --- /dev/null +++ b/mlir/docs/TestingGuide.md @@ -0,0 +1,171 @@ +# Testing Guide + +Testing is an integral part of any software infrastructure. In general, all +commits to the MLIR repository should include an accompanying test of some form. +Commits that include no functional changes, such as API changes like symbol +renaming, should be tagged with NFC(no functional changes). This signals to the +reviewer why the change doesn't/shouldn't include a test. + +MLIR generally separates testing into two main categories, [Check](#check-tests) +tests and [Unit](#unit-tests) tests. + +## Check tests + +Check tests are tests that verify that some set of string tags appear in the +output of some program. These tests generally encompass anything related to the +state of the IR (and more); analysis, parsing, transformation, verification, +etc. They are written utilizing several different tools: + +### FileCheck tests + +[FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) is a utility tool +that "reads two files (one from standard input, and one specified on the command +line) and uses one to verify the other." Essentially, one file contains a set of +tags that are expected to appear in the output file. MLIR utilizes FileCheck, in +combination with [lit](https://llvm.org/docs/CommandGuide/lit.html), to verify +different aspects of the IR - such as the output of a transformation pass. 
+ +An example FileCheck test is shown below: + +```mlir +// RUN: mlir-opt %s -cse | FileCheck %s + +// CHECK-LABEL: func @simple_constant +func @simple_constant() -> (i32, i32) { + // CHECK-NEXT: %[[RESULT:.*]] = constant 1 + // CHECK-NEXT: return %[[RESULT]], %[[RESULT]] + + %0 = constant 1 : i32 + %1 = constant 1 : i32 + return %0, %1 : i32, i32 +} +``` + +The above test performs a check that after running Common Sub-Expression +elimination, only one constant remains in the IR. + +#### FileCheck best practices + +FileCheck is an extremely useful utility, it allows for easily matching various +parts of the output. This ease of use means that it becomes easy to write +brittle tests that are essentially `diff` tests. FileCheck tests should be as +self-contained as possible and focus on testing the minimal set of +functionalities needed. Let's see an example: + +```mlir +// RUN: mlir-opt %s -cse | FileCheck %s + +// CHECK-LABEL: func @simple_constant() -> (i32, i32) +func @simple_constant() -> (i32, i32) { + // CHECK-NEXT: %result = constant 1 : i32 + // CHECK-NEXT: return %result, %result : i32, i32 + // CHECK-NEXT: } + + %0 = constant 1 : i32 + %1 = constant 1 : i32 + return %0, %1 : i32, i32 +} +``` + +The above example is another way to write the original example shown in the main +[FileCheck tests](#filecheck-tests) section. There are a few problems with this +test; below is a breakdown of the no-nos of this test to specifically highlight +best practices. + +* Tests should be self-contained. + +This means that tests should not test lines or sections outside of what is +intended. In the above example, we see lines such as `CHECK-NEXT: }`. This line +in particular is testing pieces of the Parser/Printer of FuncOp, which is +outside of the realm of concern for the CSE pass. This line should be removed. + +* Tests should be minimal, and only check what is absolutely necessary. 
+ +This means that anything in the output that is not core to the functionality +that you are testing should *not* be present in a CHECK line. This is a separate +bullet just to highlight the importance of it, especially when checking against +IR output. + +If we naively remove the unrelated `CHECK` lines in our source file, we may end +up with: + +```mlir +// CHECK-LABEL: func @simple_constant +func @simple_constant() -> (i32, i32) { + // CHECK-NEXT: %result = constant 1 : i32 + // CHECK-NEXT: return %result, %result : i32, i32 + + %0 = constant 1 : i32 + %1 = constant 1 : i32 + return %0, %1 : i32, i32 +} +``` + +It may seem like this is a minimal test case, but it still checks several +aspects of the output that are unrelated to the CSE transformation. Namely the +result types of the `constant` and `return` operations, as well the actual SSA +value names that are produced. FileCheck `CHECK` lines may contain +[regex statements](https://llvm.org/docs/CommandGuide/FileCheck.html#filecheck-regex-matching-syntax) +as well as named +[string substitution blocks](https://llvm.org/docs/CommandGuide/FileCheck.html#filecheck-string-substitution-blocks). +Utilizing the above, we end up with the example shown in the main +[FileCheck tests](#filecheck-tests) section. + +```mlir +// CHECK-LABEL: func @simple_constant +func @simple_constant() -> (i32, i32) { + /// Here we use a substitution variable as the output of the constant is + /// useful for the test, but we omit as much as possible of everything else. + // CHECK-NEXT: %[[RESULT:.*]] = constant 1 + // CHECK-NEXT: return %[[RESULT]], %[[RESULT]] + + %0 = constant 1 : i32 + %1 = constant 1 : i32 + return %0, %1 : i32, i32 +} +``` + +### Diagnostic verification tests + +MLIR provides rich source location tracking that can be used to emit errors, +warnings, etc. easily from anywhere throughout the codebase. 
Certain classes of
+tests are written to check that certain diagnostics are emitted for a given
+input program, such as an MLIR file. These tests are useful in that they allow
+checking specific invariants of the IR without transforming or changing
+anything. Some examples of tests in this category are: those that verify
+invariants of operations, or check the expected results of an analysis.
+Diagnostic verification tests are written utilizing the
+[source manager verifier handler](Diagnostics.md#sourcemgr-diagnostic-verifier-handler),
+accessible via the `verify-diagnostics` flag in mlir-opt.
+
+An example .mlir test running under `mlir-opt` is shown below:
+
+```mlir
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Expect an error on the same line.
+func @bad_branch() {
+  br ^missing // expected-error {{reference to an undefined block}}
+}
+
+// -----
+
+// Expect an error on an adjacent line.
+func @foo(%a : f32) {
+  // expected-error@+1 {{unknown comparison predicate "foo"}}
+  %result = cmpf "foo", %a, %a : f32
+  return
+}
+```
+
+## Unit tests
+
+Unit tests are written using
+[Google Test](https://github.com/google/googletest/blob/master/googletest/docs/primer.md)
+and are located in the unittests/ directory. Tests of this form *should* be
+limited to API tests that cannot be reasonably written as [Check](#check-tests)
+tests, e.g. those for data structures. It is important to keep in mind that the
+C++ APIs are not stable, and evolve over time. As such, directly testing the C++
+IR interfaces makes the tests more fragile as those C++ APIs evolve over time.
+This makes future API refactorings, which may happen frequently, much more
+cumbersome as the number of tests scales. 
diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md new file mode 100644 index 0000000000000000000000000000000000000000..b233f9bef66dd7d4955f7f9f04f7ef7055f48785 --- /dev/null +++ b/mlir/docs/Traits.md @@ -0,0 +1,246 @@ +# Introduction to MLIR Operation Traits + +[TOC] + +MLIR allows for a truly open operation ecosystem, as any dialect may define +operations that suit a specific level of abstraction. `Traits` are a mechanism +in which to abstract implementation details and properties that are common +across many different operations. `Traits` may be used to specify special +properties and constraints of the operation, including whether the operation has +side effects or whether its output has the same type as the input. Some examples +of traits are `Commutative`, `SingleResult`, `Terminator`, etc. See the more +[comprehensive list](#traits) below for more examples of what is possible. + +## Defining a Trait + +Traits may be defined in C++ by inheriting from the +`OpTrait::TraitBase` class. This base class takes as +template parameters: + +* ConcreteType + - The concrete operation type that this trait was attached to. +* TraitType + - The type of the trait class that is being defined, for use with the + [`Curiously Recurring Template Pattern`](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern). + +A derived trait class is expected to take a single template that corresponds to +the `ConcreteType`. An example trait definition is shown below: + +```c++ +template +class MyTrait : public OpTrait::TraitBase { +}; +``` + +Derived traits may also provide a `verifyTrait` hook, that is called when +verifying the concrete operation. The trait verifiers will currently always be +invoked before the main `Op::verify`. + +```c++ +template +class MyTrait : public OpTrait::TraitBase { +public: + /// Override the 'verifyTrait' hook to add additional verification on the + /// concrete operation. + static LogicalResult verifyTrait(Operation *op) { + // ... 
+ } +}; +``` + +Note: It is generally good practice to define the implementation of the +`verifyTrait` hook out-of-line as a free function when possible to avoid +instantiating the implementation for every concrete operation type. + +### Parametric Traits + +The above demonstrates the definition of a simple self-contained trait. It is +also often useful to provide some static parameters to the trait to control its +behavior. Given that the definition of the trait class is rigid, i.e. we must +have a single template argument for the concrete operation, the templates for +the parameters will need to be split out. An example is shown below: + +```c++ +template +class MyParametricTrait { +public: + template + class Impl : public OpTrait::TraitBase { + // Inside of 'Impl' we have full access to the template parameters + // specified above. + }; +}; +``` + +## Attaching a Trait + +Traits may be used when defining a derived operation type, by simply adding the +name of the trait class to the `Op` class after the concrete operation type: + +```c++ +/// Here we define 'MyOp' along with the 'MyTrait' and `MyParametric trait +/// classes we defined previously. +class MyOp : public Op::Impl> {}; +``` + +To use a trait in the [ODS](OpDefinitions.md) framework, we need to provide a +definition of the trait class. This can be done using the `NativeOpTrait` and +`ParamNativeOpTrait` classes. `ParamNativeOpTrait` provides a mechanism in which +to specify arguments to a parametric trait class with an internal `Impl`. + +```tablegen +// The argument is the c++ trait class name. +def MyTrait : NativeOpTrait<"MyTrait">; + +// The first argument is the parent c++ class name. The second argument is a +// string containing the parameter list. 
+class MyParametricTrait + : NativeOpTrait<"MyParametricTrait", !cast(!head(parameters))>; +``` + +These can then be used in the `traits` list of an op definition: + +```tablegen +def OpWithInferTypeInterfaceOp : Op<...[MyTrait, MyParametricTrait<10>]> { ... } +``` + +See the documentation on [operation definitions](OpDefinitions.md) for more +details. + +## Using a Trait + +Traits may be used to provide additional methods, static fields, or other +information directly on the concrete operation. `Traits` internally become +`Base` classes of the concrete operation, so all of these are directly +accessible. To expose this information opaquely to transformations and analyses, +[`interfaces`](Interfaces.md) may be used. + +To query if a specific operation contains a specific trait, the `hasTrait<>` +method may be used. This takes as a template parameter the trait class, which is +the same as the one passed when attaching the trait to an operation. + +```c++ +Operation *op = ..; +if (op->hasTrait() || op->hasTrait::Impl>()) + ...; +``` + +## Trait List + +MLIR provides a suite of traits that provide various functionalities that are +common across many different operations. Below is a list of some key traits that +may be used directly by any dialect. The format of the header for each trait +section goes as follows: + +* `Header` + - (`C++ class` -- `ODS class`(if applicable)) + +### Broadcastable + +* `OpTrait::BroadcastableTwoOperandsOneResult` -- `Broadcastable` + +This trait provides the API for operations that are known to have +[broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +operand and result types. Specifically, starting from the most varying +dimension, each dimension pair of the two operands' types should either be the +same or one of them is one. Also, the result type should have the corresponding +dimension equal to the larger one, if known. Shapes are checked partially if +ranks or dimensions are not known. 
For example, an op with `tensor` and
+`tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type is
+broadcast-compatible.
+
+This trait assumes the op has two operands and one result, and it asserts if the
+pre-condition is not satisfied.
+
+### Commutative
+
+* `OpTrait::IsCommutative` -- `Commutative`
+
+This trait adds the property that the operation is commutative, i.e. `X op Y ==
+Y op X`
+
+### Function-Like
+
+* `OpTrait::FunctionLike`
+
+This trait provides APIs for operations that behave like functions. In
+particular:
+
+- Ops must be symbols, i.e. also have the `Symbol` trait;
+- Ops have a single region with multiple blocks that corresponds to the body
+  of the function;
+- the absence of a region corresponds to an external function;
+- arguments of the first block of the region are treated as function
+  arguments;
+- they can have argument and result attributes that are stored in dictionary
+  attributes on the operation itself.
+
+This trait does *NOT* provide type support for the functions, meaning that
+concrete Ops must handle the type of the declared or defined function.
+`getTypeAttrName()` is a convenience function that returns the name of the
+attribute that can be used to store the function type, but the trait makes no
+assumption based on it.
+
+### HasParent
+
+* `OpTrait::HasParent` -- `HasParent`
+
+This trait provides APIs and verifiers for operations that can only be nested
+within regions that are attached to operations of `ParentOpType`.
+
+### IsolatedFromAbove
+
+* `OpTrait::IsIsolatedFromAbove` -- `IsolatedFromAbove`
+
+This trait signals that the regions of an operation are known to be isolated
+from above. This trait asserts that the regions of an operation will not
+capture, or reference, SSA values defined above the region scope. 
This means +that the following is invalid if `foo.region_op` is defined as +`IsolatedFromAbove`: + +```mlir +%result = constant 10 : i32 +foo.region_op { + foo.yield %result : i32 +} +``` + +This trait is an important structural property of the IR, and enables operations +to have [passes](WritingAPass.md) scheduled under them. + +### NoSideEffect + +* `OpTrait::HasNoSideEffect` -- `NoSideEffect` + +This trait signifies that the operation is pure and has no visible side effects. + +### Single Block with Implicit Terminator + +* `OpTrait::SingleBlockImplicitTerminator` : + `SingleBlockImplicitTerminator` + +This trait provides APIs and verifiers for operations with regions that have a +single block that must terminate with `TerminatorOpType`. + +### Symbol + +* `OpTrait::Symbol` -- `Symbol` + +This trait is used for operations that define a `Symbol`. + +TODO(riverriddle) Link to the proper document detailing the design of symbols. + +### SymbolTable + +* `OpTrait::SymbolTable` -- `SymbolTable` + +This trait is used for operations that define a `SymbolTable`. + +TODO(riverriddle) Link to the proper document detailing the design of symbols. + +### Terminator + +* `OpTrait::IsTerminator` -- `Terminator` + +This trait provides verification and functionality for operations that are known +to be [terminators](LangRef.md#terminator-operations). diff --git a/mlir/docs/Tutorials/Toy/Ch-1.md b/mlir/docs/Tutorials/Toy/Ch-1.md new file mode 100644 index 0000000000000000000000000000000000000000..cb7f97cb3f69e003c714b06546a1aa2f6073406a --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-1.md @@ -0,0 +1,169 @@ +# Chapter 1: Toy Tutorial Introduction + +[TOC] + +This tutorial runs through the implementation of a basic toy language on top of +MLIR. 
The goal of this tutorial is to introduce the concepts of MLIR; in
+particular, how [dialects](../../LangRef.md#dialects) can help easily support
+language specific constructs and transformations while still offering an easy
+path to lower to LLVM or other codegen infrastructure. This tutorial is based on
+the model of the
+[LLVM Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html).
+
+This tutorial assumes you have cloned and built MLIR; if you have not yet done
+so, see
+[Getting started with MLIR](https://github.com/tensorflow/mlir#getting-started-with-mlir).
+
+## The Chapters
+
+This tutorial is divided into the following chapters:
+
+- [Chapter #1](Ch-1.md): Introduction to the Toy language and the definition
+  of its AST.
+- [Chapter #2](Ch-2.md): Traversing the AST to emit a dialect in MLIR,
+  introducing base MLIR concepts. Here we show how to start attaching
+  semantics to our custom operations in MLIR.
+- [Chapter #3](Ch-3.md): High-level language-specific optimization using
+  pattern rewriting system.
+- [Chapter #4](Ch-4.md): Writing generic dialect-independent transformations
+  with Interfaces. Here we will show how to plug dialect specific information
+  into generic transformations like shape inference and inlining.
+- [Chapter #5](Ch-5.md): Partially lowering to lower-level dialects. We'll
+  convert some of our high-level language specific semantics towards a generic
+  affine oriented dialect for optimization.
+- [Chapter #6](Ch-6.md): Lowering to LLVM and code generation. Here we'll
+  target LLVM IR for code generation, and detail more of the lowering
+  framework.
+- [Chapter #7](Ch-7.md): Extending Toy: Adding support for a composite type.
+  We'll demonstrate how to add a custom type to MLIR, and how it fits in the
+  existing pipeline.
+
+## The Language
+
+This tutorial will be illustrated with a toy language that we’ll call “Toy”
+(naming is hard...). 
Toy is a tensor-based language that allows you to define +functions, perform some math computation, and print results. + +Given that we want to keep things simple, the codegen will be limited to tensors +of rank <= 2, and the only datatype in Toy is a 64-bit floating point type (aka +‘double’ in C parlance). As such, all values are implicitly double precision, +`Values` are immutable (i.e. every operation returns a newly allocated value), +and deallocation is automatically managed. But enough with the long description; +nothing is better than walking through an example to get a better understanding: + +```Toy {.toy} +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + + # b is identical to a, the literal tensor is implicitly reshaped: defining new + # variables is the way to reshape tensors (element count must match). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # transpose() and print() are the only builtin, the following will transpose + # a and b and perform an element-wise multiplication before printing the result. + print(transpose(a) * transpose(b)); +} +``` + +Type checking is statically performed through type inference; the language only +requires type declarations to specify tensor shapes when needed. Functions are +generic: their parameters are unranked (in other words, we know these are +tensors, but we don't know their dimensions). They are specialized for every +newly discovered signature at call sites. Let's revisit the previous example by +adding a user-defined function: + +```Toy {.toy} +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. 
+ var a = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <3, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return <3, 2>. + var d = multiply_transpose(b, a); + + # A new call with <3, 2> (instead of <2, 3>) for both dimensions will + # trigger another specialization of `multiply_transpose`. + var e = multiply_transpose(c, d); + + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var f = multiply_transpose(transpose(a), c); +} +``` + +## The AST + +The AST from the above code is fairly straightforward; here is a dump of it: + +``` +Module: + Function + Proto 'multiply_transpose' @test/ast.toy:5:1' + Args: [a, b] + Block { + Return + BinOp: * @test/ast.toy:6:25 + Call 'transpose' [ @test/ast.toy:6:10 + var: a @test/ast.toy:6:20 + ] + Call 'transpose' [ @test/ast.toy:6:25 + var: b @test/ast.toy:6:35 + ] + } // Block + Function + Proto 'main' @test/ast.toy:9:1' + Args: [] + Block { + VarDecl a<> @test/ast.toy:11:3 + Literal: <2, 3>[<3>[1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[4.000000e+00, 5.000000e+00, 6.000000e+00]] @test/ast.toy:11:17 + VarDecl b<2, 3> @test/ast.toy:12:3 + Literal: <6>[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @test/ast.toy:12:17 + VarDecl c<> @test/ast.toy:15:3 + Call 'multiply_transpose' [ @test/ast.toy:15:11 + var: a @test/ast.toy:15:30 + var: b @test/ast.toy:15:33 + ] + VarDecl d<> @test/ast.toy:18:3 + Call 'multiply_transpose' [ @test/ast.toy:18:11 + var: b @test/ast.toy:18:30 + var: a @test/ast.toy:18:33 + ] + VarDecl e<> @test/ast.toy:21:3 + Call 'multiply_transpose' [ @test/ast.toy:21:11 + var: b @test/ast.toy:21:30 + var: c @test/ast.toy:21:33 + ] 
+ VarDecl f<> @test/ast.toy:24:3 + Call 'multiply_transpose' [ @test/ast.toy:24:11 + Call 'transpose' [ @test/ast.toy:24:30 + var: a @test/ast.toy:24:40 + ] + var: c @test/ast.toy:24:44 + ] + } // Block +``` + +You can reproduce this result and play with the example in the +`examples/toy/Ch1/` directory; try running `path/to/BUILD/bin/toyc-ch1 +test/Examples/Toy/Ch1/ast.toy -emit=ast`. + +The code for the lexer is fairly straightforward; it is all in a single header: +`examples/toy/Ch1/include/toy/Lexer.h`. The parser can be found in +`examples/toy/Ch1/include/toy/Parser.h`; it is a recursive descent parser. If +you are not familiar with such a Lexer/Parser, these are very similar to the +LLVM Kaleidoscope equivalent that are detailed in the first two chapters of the +[Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl02.html). + +The [next chapter](Ch-2.md) will demonstrate how to convert this AST into MLIR. diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md new file mode 100755 index 0000000000000000000000000000000000000000..ce46788f4aefb5af03a412d1b1b7b19063453c46 --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-2.md @@ -0,0 +1,577 @@ +# Chapter 2: Emitting Basic MLIR + +[TOC] + +Now that we're familiar with our language and the AST, let's see how MLIR can +help to compile Toy. + +## Introduction: Multi-Level Intermediate Representation + +Other compilers, like LLVM (see the +[Kaleidoscope tutorial](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html)), +offer a fixed set of predefined types and (usually *low-level* / RISC-like) +instructions. It is up to the frontend for a given language to perform any +language-specific type-checking, analysis, or transformation before emitting +LLVM IR. For example, Clang will use its AST to perform not only static analysis +but also transformations, such as C++ template instantiation through AST cloning +and rewrite. 
Finally, languages with constructs at a higher level than C/C++
+may require non-trivial lowering from their AST to generate LLVM IR.
+
+As a consequence, multiple frontends end up reimplementing significant pieces of
+infrastructure to support the need for these analyses and transformations. MLIR
+addresses this issue by being designed for extensibility. As such, there are few
+pre-defined instructions (*operations* in MLIR terminology) or types.
+
+## Interfacing with MLIR
+
+[Language reference](../../LangRef.md)
+
+MLIR is designed to be a completely extensible infrastructure; there is no
+closed set of attributes (think: constant metadata), operations, or types. MLIR
+supports this extensibility with the concept of
+[Dialects](../../LangRef.md#dialects). Dialects provide a grouping mechanism for
+abstraction under a unique `namespace`.
+
+In MLIR, [`Operations`](../../LangRef.md#operations) are the core unit of
+abstraction and computation, similar in many ways to LLVM instructions.
+Operations can have application-specific semantics and can be used to represent
+all of the core IR structures in LLVM: instructions, globals (like functions),
+modules, etc.
+
+Here is the MLIR assembly for the Toy `transpose` operation:
+
+```mlir
+%t_tensor = "toy.transpose"(%tensor) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64> loc("example/file/path":12:1)
+```
+
+Let's break down the anatomy of this MLIR operation:
+
+- `%t_tensor`
+
+    * The name given to the result defined by this operation (which includes
+      [a prefixed sigil to avoid collisions](../../LangRef.md#identifiers-and-keywords)).
+      An operation may define zero or more results (in the context of Toy, we
+      will limit ourselves to single-result operations), which are SSA values.
+      The name is used during parsing but is not persistent (e.g., it is not
+      tracked in the in-memory representation of the SSA value).
+
+- `"toy.transpose"`
+
+    * The name of the operation. 
It is expected to be a unique string, with + the namespace of the dialect prefixed before the "`.`". This can be read + as the `transpose` operation in the `toy` dialect. + +- `(%tensor)` + + * A list of zero or more input operands (or arguments), which are SSA + values defined by other operations or referring to block arguments. + +- `{ inplace = true }` + + * A dictionary of zero or more attributes, which are special operands that + are always constant. Here we define a boolean attribute named 'inplace' + that has a constant value of true. + +- `(tensor<2x3xf64>) -> tensor<3x2xf64>` + + * This refers to the type of the operation in a functional form, spelling + the types of the arguments in parentheses and the type of the return + values afterward. + +- `loc("example/file/path":12:1)` + + * This is the location in the source code from which this operation + originated. + +Shown here is the general form of an operation. As described above, the set of +operations in MLIR is extensible. This means that the infrastructure must be +able to opaquely reason about the structure of an operation. This is done by +boiling down the composition of an operation into discrete pieces: + +- A name for the operation. +- A list of SSA operand values. +- A list of [attributes](../../LangRef.md#attributes). +- A list of [types](../../LangRef.md#type-system) for result values. +- A [source location](../../Diagnostics.md#source-locations) for debugging + purposes. +- A list of successors [blocks](../../LangRef.md#blocks) (for branches, + mostly). +- A list of [regions](../../LangRef.md#regions) (for structural operations + like functions). + +In MLIR, every operation has a mandatory source location associated with it. +Contrary to LLVM, where debug info locations are metadata and can be dropped, in +MLIR, the location is a core requirement, and APIs depend on and manipulate it. +Dropping a location is thus an explicit choice which cannot happen by mistake. 
+ +To provide an illustration: If a transformation replaces an operation by +another, that new operation must still have a location attached. This makes it +possible to track where that operation came from. + +It's worth noting that the mlir-opt tool - a tool for testing +compiler passes - does not include locations in the output by default. The +`-mlir-print-debuginfo` flag specifies to include locations. (Run `mlir-opt +--help` for more options.) + +### Opaque API + +MLIR is designed to be a completely extensible system, and as such, the +infrastructure has the capability to opaquely represent all of its core +components: attributes, operations, types, etc. This allows MLIR to parse, +represent, and [round-trip](../../Glossary.md#round-trip) any valid IR. For +example, we could place our Toy operation from above into an `.mlir` file and +round-trip through *mlir-opt* without registering anything: + +```mlir +func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { + %t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64> + return %t_tensor : tensor<3x2xf64> +} +``` + +In the cases of unregistered attributes, operations, and types, MLIR will +enforce some structural constraints (SSA, block termination, etc.), but +otherwise they are completely opaque. This can be useful for bootstrapping +purposes, but it is generally advised against. Opaque operations must be treated +conservatively by transformations and analyses, and they are much harder to +construct and manipulate. + +This handling can be observed by crafting what should be an invalid IR for Toy +and seeing it round-trip without tripping the verifier: + +```mlir +// RUN: toyc %s -emit=mlir + +func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} +``` + +There are multiple problems here: the `toy.print` operation is not a terminator; +it should take an operand; and it shouldn't return any values. 
In the next +section, we will register our dialect and operations with MLIR, plug into the +verifier, and add nicer APIs to manipulate our operations. + +## Defining a Toy Dialect + +To effectively interface with MLIR, we will define a new Toy dialect. This +dialect will properly model the semantics of the Toy language, as well as +provide an easy avenue for high-level analysis and transformation. + +```c++ +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods, which will be demonstrated in later chapters of the tutorial. +class ToyDialect : public mlir::Dialect { + public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities. + static llvm::StringRef getDialectNamespace() { return "toy"; } +}; +``` + +The dialect can now be registered in the global registry: + +```c++ + mlir::registerDialect(); +``` + +Any new `MLIRContext` created from now on will contain an instance of the Toy +dialect and invoke specific hooks for things like parsing attributes and types. + +## Defining Toy Operations + +Now that we have a `Toy` dialect, we can start registering operations. This will +allow for providing semantic information that the rest of the system can hook +into. Let's walk through the creation of the `toy.constant` operation: + +```mlir + %4 = "toy.constant"() {value = dense<1.0> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +``` + +This operation takes zero operands, a +[dense elements](../../LangRef.md#dense-elements-attribute) attribute named +`value`, and returns a single result of +[TensorType](../../LangRef.md#tensor-type). 
An operation inherits from the
[CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
`mlir::Op` class which also takes some optional [*traits*](../../Traits.md) to
customize its behavior. These traits may provide additional accessors,
verification, etc.

```c++
class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
                                   mlir::OpTrait::OneResult,
                                   mlir::OpTrait::HasNoSideEffect> {
public:
  /// Inherit the constructors from the base Op class.
  using Op::Op;

  /// Provide the unique name for this operation. MLIR will use this to register
  /// the operation and uniquely identify it throughout the system.
  static llvm::StringRef getOperationName() { return "toy.constant"; }

  /// Return the value of the constant by fetching it from the attribute.
  mlir::DenseElementsAttr getValue();

  /// Operations can provide additional verification beyond the traits they
  /// define. Here we will ensure that the specific invariants of the constant
  /// operation are upheld, for example the result type must be of TensorType.
  LogicalResult verify();

  /// Provide an interface to build this operation from a set of input values.
  /// This interface is used by the builder to allow for easily generating
  /// instances of this operation:
  ///   mlir::OpBuilder::create<ConstantOp>(...)
  /// This method populates the given `state` that MLIR uses to create
  /// operations. This state is a collection of all of the discrete elements
  /// that an operation may contain.
  /// Build a constant with the given return type and `value` attribute.
  static void build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::Type result, mlir::DenseElementsAttr value);
  /// Build a constant and reuse the type from the given 'value'.
  static void build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::DenseElementsAttr value);
  /// Build a constant by broadcasting the given 'value'.
  static void build(mlir::Builder *builder, mlir::OperationState &state,
                    double value);
};
```

and we register this operation in the `ToyDialect` constructor:

```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  addOperations<ConstantOp>();
}
```

### Op vs Operation: Using MLIR Operations

Now that we have defined an operation, we will want to access and transform it.
In MLIR, there are two main classes related to operations: `Operation` and `Op`.
Operation is the actual opaque instance of the operation, and represents the
general API into an operation instance. An `Op` is the base class of a derived
operation, like `ConstantOp`, and acts as a smart pointer wrapper around an
`Operation*`. This means that when we define our Toy operations, we are actually
providing a clean interface for building and interfacing with the `Operation`
class; this is why our `ConstantOp` defines no class fields. Therefore, we
always pass these classes around by value, instead of by reference or pointer
(*passing by value* is a common idiom and applies similarly to attributes,
types, etc). We can always get an instance of our toy operation by using LLVM's
casting infrastructure:

```c++
void processConstantOp(mlir::Operation *operation) {
  ConstantOp op = llvm::dyn_cast<ConstantOp>(operation);

  // This operation is not an instance of `ConstantOp`.
  if (!op)
    return;

  // Get the internal operation instance back.
  mlir::Operation *internalOperation = op.getOperation();
  assert(internalOperation == operation &&
         "these operation instances are the same");
}
```

### Using the Operation Definition Specification (ODS) Framework

In addition to specializing the `mlir::Op` C++ template, MLIR also supports
defining operations in a declarative manner. This is achieved via the
[Operation Definition Specification](../../OpDefinitions.md) framework.
Facts
regarding an operation are specified concisely into a TableGen record, which
will be expanded into an equivalent `mlir::Op` C++ template specialization at
compile time. Using the ODS framework is the desired way for defining operations
in MLIR given the simplicity, conciseness, and general stability in the face of
C++ API changes.

Let's see how to define the ODS equivalent of our ConstantOp:

The first thing to do is to define a link to the Toy dialect that we defined in
C++. This is used to link all of the operations that we will define to our
dialect:

```tablegen
// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
  // The namespace of our dialect, this corresponds 1-1 with the string we
  // provided in `ToyDialect::getDialectNamespace`.
  let name = "toy";

  // The C++ namespace that the dialect class definition resides in.
  let cppNamespace = "toy";
}
```

Now that we have defined a link to the Toy dialect, we can start defining
operations. Operations in ODS are defined by inheriting from the `Op` class. To
simplify our operation definitions, we will define a base class for operations
in the Toy dialect.

```tablegen
// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
//   * The parent dialect of the operation.
//   * The mnemonic for the operation, or the name without the dialect prefix.
//   * A list of traits for the operation.
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<Toy_Dialect, mnemonic, traits>;
```

With all of the preliminary pieces defined, we can begin to define the constant
operation.

We define a toy operation by inheriting from our base 'Toy_Op' class above. Here
we provide the mnemonic and a list of traits for the operation. The
[mnemonic](../../OpDefinitions.md#operation-name) here matches the one given in
`ConstantOp::getOperationName` without the dialect prefix: `toy.`.
The constant +operation here is also marked as 'NoSideEffect'. This is an ODS trait, and +matches one-to-one with the trait we providing when defining `ConstantOp`: +`mlir::OpTrait::HasNoSideEffect`. Missing here from our C++ definition are the +`ZeroOperands` and `OneResult` traits; these will be automatically inferred +based upon the `arguments` and `results` fields we define later. + +```tablegen +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { +} +``` + +At this point you probably might want to know what the C++ code generated by +TableGen looks like. Simply run the `mlir-tblgen` command with the +`gen-op-decls` or the `gen-op-defs` action like so: + +``` +${build_root}/bin/mlir-tblgen -gen-op-defs ${mlir_src_root}/examples/toy/Ch2/include/toy/Ops.td -I ${mlir_src_root}/include/ +``` + +Depending on the selected action, this will print either the `ConstantOp` class +declaration or its implementation. Comparing this output to the hand-crafted +implementation is incredibly useful when getting started with TableGen. + +#### Defining Arguments and Results + +With the shell of the operation defined, we can now provide the +[inputs](../../OpDefinitions.md#operation-arguments) and +[outputs](../../OpDefinitions.md#operation-results) to our operation. The +inputs, or arguments, to an operation may be attributes or types for SSA operand +values. The results correspond to a set of types for the values produced by the +operation: + +```tablegen +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // The constant operation takes an attribute as the only input. + // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + // F64Tensor corresponds to a 64-bit floating-point TensorType. + let results = (outs F64Tensor); +} +``` + +By providing a name to the arguments or results, e.g. 
`$value`, ODS will +automatically generate a matching accessor: `DenseElementsAttr +ConstantOp::value()`. + +#### Adding Documentation + +The next step after defining the operation is to document it. Operations may +provide +[`summary` and `description`](../../OpDefinitions.md#operation-documentation) +fields to describe the semantics of the operation. This information is useful +for users of the dialect and can even be used to auto-generate Markdown +documents. + +```tablegen +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant operation"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + }]; + + // The constant operation takes an attribute as the only input. + // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr. + let arguments = (ins F64ElementsAttr:$value); + + // The generic call operation returns a single value of TensorType. + // F64Tensor corresponds to a 64-bit floating-point TensorType. + let results = (outs F64Tensor); +} +``` + +#### Verifying Operation Semantics + +At this point we've already covered a majority of the original C++ operation +definition. The next piece to define is the verifier. Luckily, much like the +named accessor, the ODS framework will automatically generate a lot of the +necessary verification logic based upon the constraints we have given. This +means that we don't need to verify the structure of the return type, or even the +input attribute `value`. In many cases, additional verification is not even +necessary for ODS operations. 
To add additional verification logic, an operation +can override the [`verifier`](../../OpDefinitions.md#custom-verifier-code) +field. The `verifier` field allows for defining a C++ code blob that will be run +as part of `ConstantOp::verify`. This blob can assume that all of the other +invariants of the operation have already been verified: + +```tablegen +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant operation"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + }]; + + // The constant operation takes an attribute as the only input. + // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr. + let arguments = (ins F64ElementsAttr:$value); + + // The generic call operation returns a single value of TensorType. + // F64Tensor corresponds to a 64-bit floating-point TensorType. + let results = (outs F64Tensor); + + // Add additional verification logic to the constant operation. Here we invoke + // a static `verify` method in a C++ source file. This codeblock is executed + // inside of ConstantOp::verify, so we can use `this` to refer to the current + // operation instance. + let verifier = [{ return ::verify(*this); }]; +} +``` + +#### Attaching `build` Methods + +The final missing component here from our original C++ example are the `build` +methods. ODS can generate some simple build methods automatically, and in this +case it will generate our first build method for us. For the rest, we define the +[`builders`](../../OpDefinitions.md#custom-builder-methods) field. 
This field +takes a list of `OpBuilder` objects that take a string corresponding to a list +of C++ parameters, as well as an optional code block that can be used to specify +the implementation inline. + +```tablegen +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant operation"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + }]; + + // The constant operation takes an attribute as the only input. + // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr. + let arguments = (ins F64ElementsAttr:$value); + + // The generic call operation returns a single value of TensorType. + // F64Tensor corresponds to a 64-bit floating-point TensorType. + let results = (outs F64Tensor); + + // Add additional verification logic to the constant operation. Here we invoke + // a static `verify` method in a c++ source file. This codeblock is executed + // inside of ConstantOp::verify, so we can use `this` to refer to the current + // operation instance. + let verifier = [{ return ::verify(*this); }]; + + // Add custom build methods for the constant operation. These methods populate + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &result, " + "DenseElementsAttr value", [{ + // Call into an autogenerated `build` method. + build(builder, result, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. 
This builder + // creates a declaration for `ConstantOp::build` with the given parameters. + OpBuilder<"Builder *builder, OperationState &result, double value"> + ]; +} +``` + +Above we introduce several of the concepts for defining operations in the ODS +framework, but there are many more that we haven't had a chance to: regions, +variadic operands, etc. Check out the +[full specification](../../OpDefinitions.md) for more details. + +## Complete Toy Example + +At this point we can generate our "Toy IR". A simplified version of the previous +example: + +```.toy +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} +``` + +Results in the following IR: + +```mlir +module { + func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:10) + %1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:25) + %2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:25) + "toy.return"(%2) : (tensor<*xf64>) -> () loc("test/codegen.toy":5:3) + } loc("test/codegen.toy":4:1) + func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> loc("test/codegen.toy":9:17) + %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> loc("test/codegen.toy":9:3) + %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> loc("test/codegen.toy":10:17) + %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> 
tensor<2x3xf64> loc("test/codegen.toy":10:3) + %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/codegen.toy":11:11) + %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/codegen.toy":12:11) + "toy.print"(%5) : (tensor<*xf64>) -> () loc("test/codegen.toy":13:3) + "toy.return"() : () -> () loc("test/codegen.toy":8:1) + } loc("test/codegen.toy":8:1) +} loc("test/codegen.toy":0:0) +``` + +You can build `toyc-ch2` and try yourself: `toyc-ch2 +test/Examples/Toy/Ch2/codegen.toy -emit=mlir -mlir-print-debuginfo`. We can also +check our RoundTrip: `toyc-ch2 test/Examples/Toy/Ch2/codegen.toy -emit=mlir +-mlir-print-debuginfo 2> codegen.mlir` followed by `toyc-ch2 codegen.mlir +-emit=mlir`. You should also use `mlir-tblgen` on the final definition file and +study the generated C++ code. + +At this point, MLIR knows about our Toy dialect and operations. In the +[next chapter](Ch-3.md), we will leverage our new dialect to implement some +high-level language-specific analyses and transformations for the Toy language. diff --git a/mlir/docs/Tutorials/Toy/Ch-3.md b/mlir/docs/Tutorials/Toy/Ch-3.md new file mode 100644 index 0000000000000000000000000000000000000000..615c2c1bbecc3f67a50734614da906d6af6582e0 --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-3.md @@ -0,0 +1,264 @@ +# Chapter 3: High-level Language-Specific Analysis and Transformation + +[TOC] + +Creating a dialect that closely represents the semantics of an input language +enables analyses, transformations and optimizations in MLIR that require +high-level language information and are generally performed on the language AST. +For example, `clang` has a fairly +[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html) +for performing template instantiation in C++. + +We divide compiler transformations into two categories: local and global. 
In +this chapter, we focus on how to leverage the Toy Dialect and its high-level +semantics to perform local pattern-match transformations that would be difficult +in LLVM. For this, we use MLIR's +[Generic DAG Rewriter](../../GenericDAGRewriter.md). + +There are two methods that can be used to implement pattern-match +transformations: 1. Imperative, C++ pattern-match and rewrite 2. Declarative, +rule-based pattern-match and rewrite using table-driven +[Declarative Rewrite Rules](../../DeclarativeRewrites.md) (DRR). Note that the +use of DRR requires that the operations be defined using ODS, as described in +[Chapter 2](Ch-2.md). + +# Optimize Transpose using C++ style pattern-match and rewrite + +Let's start with a simple pattern and try to eliminate a sequence of two +transpose that cancel out: `transpose(transpose(X)) -> X`. Here is the +corresponding Toy example: + +```Toy(.toy) +def transpose_transpose(x) { + return transpose(transpose(x)); +} +``` + +Which corresponds to the following IR: + +```mlir +func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> { + %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> + %1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64> + "toy.return"(%1) : (tensor<*xf64>) -> () +} +``` + +This is a good example of a transformation that is trivial to match on the Toy +IR but that would be quite hard for LLVM to figure. 
For example, today Clang
can't optimize away the temporary array, and the computation with the naive
transpose is expressed with these loops:

```c++
#define N 100
#define M 100

void sink(void *);
void double_transpose(int A[N][M]) {
  int B[M][N];
  for(int i = 0; i < N; ++i) {
    for(int j = 0; j < M; ++j) {
      B[j][i] = A[i][j];
    }
  }
  for(int i = 0; i < N; ++i) {
    for(int j = 0; j < M; ++j) {
      A[i][j] = B[j][i];
    }
  }
  sink(A);
}
```

For a simple C++ approach to rewrite involving matching a tree-like pattern in
the IR and replacing it with a different set of operations, we can plug into the
MLIR `Canonicalizer` pass by implementing a `RewritePattern`:

```c++
/// Fold transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// We register this pattern to match every toy.transpose in the IR.
  /// The "benefit" is used by the framework to order the patterns and process
  /// them in order of profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// This method is attempting to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. It is expected
  /// to interact with it to perform any changes to the IR from here.
  mlir::PatternMatchResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current transpose.
    mlir::Value transposeInput = op.getOperand();
    TransposeOp transposeInputOp =
        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
    // If the input is defined by another Transpose, bingo!
    if (!transposeInputOp)
      return matchFailure();

    // Use the rewriter to perform the replacement
    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
    return matchSuccess();
  }
};
```

The implementation of this rewriter is in `ToyCombine.cpp`.
The
[canonicalization pass](../../Canonicalization.md) applies transformations
defined by operations in a greedy, iterative manner. To ensure that the
canonicalization pass applies our new transform, we set
[hasCanonicalizer = 1](../../OpDefinitions.md#hascanonicalizer) and register the
pattern with the canonicalization framework.

```c++
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyRedundantTranspose>(context);
}
```

We also need to update our main file, `toyc.cpp`, to add an optimization
pipeline. In MLIR, the optimizations are run through a `PassManager` in a
similar way to LLVM:

```c++
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
```

Finally, we can run `toyc-ch3 test/transpose_transpose.toy -emit=mlir -opt` and
observe our pattern in action:

```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
  "toy.return"(%arg0) : (tensor<*xf64>) -> ()
}
```

As expected, we now directly return the function argument, bypassing any
transpose operation. However, one of the transposes still hasn't been
eliminated. That is not ideal! What happened is that our pattern replaced the
last transform with the function input and left behind the now dead transpose
input. The Canonicalizer knows to clean up dead operations; however, MLIR
conservatively assumes that operations may have side-effects. We can fix this by
adding a new trait, `NoSideEffect`, to our `TransposeOp`:

```tablegen:
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {...}
```

Let's retry now `toyc-ch3 test/transpose_transpose.toy -emit=mlir -opt`:

```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  "toy.return"(%arg0) : (tensor<*xf64>) -> ()
}
```

Perfect!
No `transpose` operation is left - the code is optimal.

In the next section, we use DRR for pattern match optimizations associated with
the Reshape op.

# Optimize Reshapes using DRR

Declarative, rule-based pattern-match and rewrite (DRR) is an operation
DAG-based declarative rewriter that provides a table-based syntax for
pattern-match and rewrite rules:

```tablegen:
class Pattern<
    dag sourcePattern, list<dag> resultPatterns,
    list<dag> additionalConstraints = [],
    dag benefitsAdded = (addBenefit 0)>;
```

A redundant reshape optimization similar to SimplifyRedundantTranspose can be
expressed more simply using DRR as follows:

```tablegen:
// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;
```

The automatically generated C++ code corresponding to each of the DRR patterns
can be found under path/to/BUILD/projects/mlir/examples/toy/Ch3/ToyCombine.inc.

DRR also provides a method for adding argument constraints when the
transformation is conditional on some properties of the arguments and results.
An example is a transformation that eliminates reshapes when they are redundant,
i.e. when the input and output shapes are identical.

```tablegen:
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
  (ReshapeOp:$res $arg), (replaceWithValue $arg),
  [(TypesAreIdentical $res, $arg)]>;
```

Some optimizations may require additional transformations on instruction
arguments. This is achieved using NativeCodeCall, which allows for more complex
transformations either by calling into a C++ helper function or by using inline
C++. An example of such an optimization is FoldConstantReshape, where we
optimize Reshape of a constant value by reshaping the constant in place and
eliminating the reshape operation.
+ +```tablegen: +def ReshapeConstant : NativeCodeCall<"$0.reshape(($1->getType()).cast())">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; +``` + +We demonstrate these reshape optimizations using the following +trivialReshape.toy program: + +```c++ +def main() { + var a<2,1> = [1, 2]; + var b<2,1> = a; + var c<2,1> = b; + print(c); +} +``` + +```mlir +module { + func @main() { + %0 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>} + : () -> tensor<2xf64> + %1 = "toy.reshape"(%0) : (tensor<2xf64>) -> tensor<2x1xf64> + %2 = "toy.reshape"(%1) : (tensor<2x1xf64>) -> tensor<2x1xf64> + %3 = "toy.reshape"(%2) : (tensor<2x1xf64>) -> tensor<2x1xf64> + "toy.print"(%3) : (tensor<2x1xf64>) -> () + "toy.return"() : () -> () + } +} +``` + +We can try to run `toyc-ch3 test/trivialReshape.toy -emit=mlir -opt` and observe +our pattern in action: + +```mlir +module { + func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00], [2.000000e+00]]> \ + : tensor<2x1xf64>} : () -> tensor<2x1xf64> + "toy.print"(%0) : (tensor<2x1xf64>) -> () + "toy.return"() : () -> () + } +} +``` + +As expected, no reshape operations remain after canonicalization. + +Further details on the declarative rewrite method can be found at +[Table-driven Declarative Rewrite Rule (DRR)](../../DeclarativeRewrites.md). + +In this chapter, we saw how to use certain core transformations through always +available hooks. In the [next chapter](Ch-4.md), we will see how to use generic +solutions that scale better through Interfaces. 
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md new file mode 100644 index 0000000000000000000000000000000000000000..4a4e11c68e608aeea929cd40982d157ffbc39265 --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -0,0 +1,387 @@ +# Chapter 4: Enabling Generic Transformation with Interfaces + +[TOC] + +## Background: Grappling with an Extensible IR + +Through dialects, MLIR allows for the representation of many different levels of +abstraction; the Toy dialect that we have previously defined is one such +example. Though these different dialects may represent different abstractions, +there is often a set of common transformations and analyses that we would like +to perform. The problem that arises is that naively implementing each +transformation for each dialect leads to large amounts of code duplication, as +the internal algorithms are generally very similar, if not the same. We would +like to provide the ability for transformations to opaquely hook into dialects +like Toy to get the information they need. + +MLIR provides a set of always available-hooks for certain core transformations, +as seen in the [previous chapter](Ch-3.md), where we registered some +canonicalizations via a hook on our operations (`getCanonicalizationPatterns`). +However, these types of hooks don't really scale well. Therefore, a more generic +solution was designed, in the form of [interfaces](../../Interfaces.md), to make +the MLIR infrastructure as extensible as the representation. Interfaces provide +a generic mechanism for dialects and operations to provide information to a +transformation or analysis. + +## Shape Inference: Preparing for Code Generation + +Our Toy IR currently operates on generic tensors, meaning that we don't know the +shape of tensors other than during the initialization of constants. This +complicates optimizations, as well as code generation. 
Fortunately, we can +simply propagate the shapes through the computation until they are all known. +The issue is how to handle calls to user-defined generic functions: every call +site could deduce different shapes. One possibility would be to perform symbolic +inference based on the argument types, but this would be hard to generalize if +we were to introduce more control flow in the language. Another approach would +be function specialization, where every call site with new argument shapes +duplicates the called function and specializes it. The approach we take for Toy +is to inline all of the function calls, then perform intraprocedural shape +propagation. + +### Inlining + +Here we could write an inlining algorithm specifically designed for the Toy +dialect, but that can become quite complicated depending on the level of +complexity that we want. Disregarding cost modeling, the pure structural +transformation is already complex to implement from scratch. Thankfully, MLIR +provides a generic inliner algorithm that dialects can plug into. All we need to +do in Toy is to provide the [interfaces](../../Interfaces.md) for the inliner to +hook into. + +The first thing we need to do is to define the constraints on inlining +operations in the Toy dialect. This information is provided through a +[dialect interface](../../Interfaces.md#dialect-interfaces). This is essentially +a class containing a set of virtual hooks for which a dialect may provide a +specialization. In this case, the interface is `DialectInlinerInterface`. + +```c++ +/// This class defines the interface for handling inlining with Toy operations. +/// We simplify inherit from the base interface class and provide a +/// specialization of the necessary methods. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + /// This hook checks to see if the given operation is legal to inline into the + /// given region. 
For Toy this hook can simply return true, as all Toy
  /// operations are inlinable.
  bool isLegalToInline(Operation *, Region *,
                       BlockAndValueMapping &) const final {
    return true;
  }

  /// This hook is called when a terminator operation has been inlined. The only
  /// terminator that we have in the Toy dialect is the return
  /// operation(toy.return). We handle the return by replacing the values
  /// previously returned by the call operation with the operands of the
  /// return.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
  }
};
```

We then register our dialect interface directly on the Toy dialect, similarly to
how we did for operations.

```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addInterfaces<ToyInlinerInterface>();
}
```

Next, we need to provide a way for the inliner to know that `toy.generic_call`
represents a call to a function. MLIR provides an
[operation interface](../../Interfaces.md#operation-interfaces) that can be used
to mark an operation as being "call-like". Unlike dialect interfaces, operation
interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interface that we will be adding here is the
`CallOpInterface`.
+ +To add this interface we just need to include the definition into our operation +specification file (`Ops.td`): + +```tablegen +#ifdef MLIR_CALLINTERFACES +#else +include "mlir/Analysis/CallInterfaces.td" +#endif // MLIR_CALLINTERFACES +``` + +and add it to the traits list of `GenericCallOp`: + +```tablegen +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + ... +} +``` + +In the above we also use the `DeclareOpInterfaceMethods` directive to +auto-declare all of the interface methods in the class declaration of +GenericCallOp. This means that we just need to provide a definition: + +```c++ +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return getAttrOfType("callee"); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } +``` + +Now that the inliner has been informed about the Toy dialect, we can add the +inliner pass to the pass manager for Toy: + +```c++ + pm.addPass(mlir::createInlinerPass()); +``` + +Now let's look at a working example: + +```mlir +func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> + %1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> + %2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "toy.return"(%2) : (tensor<*xf64>) -> () +} +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> + %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> + 
%3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
+ %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ "toy.print"(%5) : (tensor<*xf64>) -> ()
+ "toy.return"() : () -> ()
+}
+```
+
+We have two calls to multiply_transpose that we would like to inline into main,
+but if we look at the output nothing has changed. We are missing one last subtle
+piece: there is a hidden type conversion on the edge of the call. If we look at
+the above, the operands to the generic_call are of type `tensor<2x3xf64>`, while
+the inputs to the function expect `tensor<*xf64>`. To resolve this difference,
+the inliner expects an explicit cast operation to be inserted. For this, we need
+to add a new operation to the Toy dialect, `ToyCastOp` (toy.cast), to represent
+casts between two different shapes.
+
+```tablegen
+def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
+ let summary = "shape cast operation";
+ let description = [{
+ The "cast" operation converts a tensor from one type to an equivalent type
+ without changing any data elements. The source and destination types
+ must both be tensor types with the same element type. If both are ranked
+ then the rank should be the same and static dimensions should match. The
+ operation is invalid if converting to a mismatching constant dimension.
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor:$output);
+
+ // Set the folder bit so that we can fold redundant cast operations.
+ let hasFolder = 1;
+}
+```
+
+We can then override the necessary hook on the ToyInlinerInterface to insert
+this for us when necessary:
+
+```c++
+struct ToyInlinerInterface : public DialectInlinerInterface {
+ ...
+ + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; +``` + +If we run the working example through the pipeline again, we get the expected: + +```mlir +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %1 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64> + %3 = "toy.cast"(%0) : (tensor<2x3xf64>) -> tensor<*xf64> + %4 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64> + %5 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64> + %6 = "toy.mul"(%4, %5) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "toy.print"(%6) : (tensor<*xf64>) -> () + "toy.return"() : () -> () +} +``` + +NOTE: The generic inliner will also perform simplifications, so the output may +be a bit cleaner than expected. + +### Intraprocedural Shape Inference + +Now that we have inlined all of the functions, we are left with a main function +containing a mix of static and dynamically shaped operations. We can now write a +simple shape inference pass to propagate shapes intraprocedurally (within a +single function). 
We could write this as a pass that directly encodes the +constraints of the operations within the Toy dialect, but this seems like a good +candidate for a transformation that could be written generically. As a good rule +of thumb, it is best to express a transformation as generically as possible, +such that it can be extended to other dialects in the future. There is no +telling how many other dialects may have similar needs or encounter the same +problems. + +For shape inference, if we break down the problem to its core, we really just +want operations to tell us the expected outputs given a set of statically known +inputs. (We can definitely get more complex than that, but for our needs we can +keep it simple.) Given that this property is core to a specific operation, we +can define an operation interface that can be specified on operations that need +to have their result shapes inferred. + +Similarly to operations, we can also +[define operation interfaces](../../OpDefinitions.md#operation-interfaces) using +the operation definition specification (ODS) framework. + +The interface is defined by inheriting from `OpInterface`, which takes the name +to be given to the generated C++ interface class as a template argument. For our +purposes, we will name the generated class a simpler `ShapeInference`. We also +provide a description for the interface. + +```tablegen +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; +} +``` + +Next, we define the interface methods that the operations will need to provide. +An interface method is comprised of: a description; a C++ return type in string +form; a method name in string form; and a few optional components, depending on +the need. See the +[ODS documentation](../../OpDefinitions.md#operation-interfaces) for more +information. 
+
+```tablegen
+def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+ let description = [{
+ Interface to access a registered method to infer the return types for an
+ operation that can be used during type inference.
+ }];
+
+ let methods = [
+ InterfaceMethod<"Infer and set the output shape for the current operation.",
+ "void", "inferShapes">
+ ];
+}
+```
+
+Now that the interface is defined, we can add it to the necessary Toy operations
+in a similar way to how we added the `CallOpInterface` to the GenericCallOp:
+
+```
+def MulOp : Toy_Op<"mul",
+ [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+ ...
+}
+```
+
+Each of these operations will then need to provide a definition for the
+`inferShapes()` method. As an example, for the mul op, the result shape is
+inferred as the shape of the inputs.
+
+```c++
+/// Infer the output shape of the MulOp, this is required by the shape inference
+/// interface.
+void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
+```
+
+At this point, each of the necessary Toy operations provides a mechanism by
+which to infer their output shapes. The ShapeInferencePass is a FunctionPass: it
+will run on each Function in isolation. MLIR also supports general
+[OperationPasses](../../WritingAPass.md#operation-pass) that run on any isolated
+operation (i.e. other function-like operations), but here our module only
+contains functions, so there is no need to generalize to all operations.
+
+Implementing such a pass is done by creating a class inheriting from
+`mlir::FunctionPass` and overriding the `runOnFunction()` method:
+
+```c++
+class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
+ void runOnFunction() override {
+ FuncOp function = getFunction();
+ ...
+ }
+};
+```
+
+The algorithm operates as follows:
+
+1. Build a worklist containing all the operations that return a dynamically
+ shaped tensor: these are the operations that need shape inference.
+2.
Iterate on the worklist: + - find an operation to process: the next ready operation in the worklist + has all of its arguments non-generic, + - if no operation is found, break out of the loop, + - remove the operation from the worklist, + - infer the shape of its output from the argument types. +3. If the worklist is empty, the algorithm succeeded. + +When processing an operation, we query if it registered the `ShapeInference` +interface. + +```c++ + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + + /// We check if an operation has a particular interface by casting. + if (ShapeInference shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } +``` + +We can then add our pass to the pass manager: + +```c++ + pm.addPass(mlir::createShapeInferencePass()); +``` + +If we rerun our original example, we now get the following: + +```mlir +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64> + %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> + "toy.print"(%2) : (tensor<3x2xf64>) -> () + "toy.return"() : () -> () +} +``` + +You can build `toyc-ch4` and try yourself: `toyc-ch4 +test/Examples/Toy/Ch4/codegen.toy -emit=mlir -opt`. + +In the [next chapter](Ch-5.md), we will start the process of code generation by +targeting a lower level dialect for optimizing some of the more compute-heavy +Toy operations. 
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md new file mode 100644 index 0000000000000000000000000000000000000000..8a4268b498fa5c3c039c1745f74dcbfc855bb74f --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -0,0 +1,357 @@ +# Chapter 5: Partial Lowering to Lower-Level Dialects for Optimization + +[TOC] + +At this point, we are eager to generate actual code and see our Toy language +take life. We will use LLVM to generate code, but just showing the LLVM builder +interface here wouldn't be very exciting. Instead, we will show how to perform +progressive lowering through a mix of dialects coexisting in the same function. + +To make it more interesting, in this chapter we will consider that we want to +reuse existing optimizations implemented in a dialect optimizing affine +transformations: `Affine`. This dialect is tailored to the computation-heavy +part of the program and is limited: it doesn't support representing our +`toy.print` builtin, for instance, neither should it! Instead, we can target +`Affine` for the computation heavy part of Toy, and in the +[next chapter](Ch-6.md) directly the `LLVM IR` dialect for lowering `print`. As +part of this lowering, we will be lowering from the +[TensorType](../../LangRef.md#tensor-type) that `Toy` operates on to the +[MemRefType](../../LangRef.md#memref-type) that is indexed via an affine +loop-nest. Tensors represent an abstract value-typed sequence of data, meaning +that they don't live in any memory. MemRefs, on the other hand, represent lower +level buffer access, as they are concrete references to a region of memory. + +# Dialect Conversions + +MLIR has many different dialects, so it is important to have a unified framework +for [converting](../../Glossary.md#conversion) between them. This is where the +`DialectConversion` framework comes into play. This framework allows for +transforming a set of `illegal` operations to a set of `legal` ones. 
To use this
+framework, we need to provide two things (and an optional third):
+
+* A [Conversion Target](../../DialectConversion.md#conversion-target)
+
+ - This is the formal specification of what operations or dialects are
+ legal for the conversion. Operations that aren't legal will require
+ rewrite patterns to perform
+ [legalization](./../../Glossary.md#legalization).
+
+* A set of
+ [Rewrite Patterns](../../DialectConversion.md#rewrite-pattern-specification)
+
+ - These are the set of [patterns](../../QuickstartRewrites.md) used to
+ convert `illegal` operations into a set of zero or more `legal` ones.
+
+* Optionally, a [Type Converter](../../DialectConversion.md#type-conversion).
+
+ - If provided, this is used to convert the types of block arguments. We
+ won't be needing this for our conversion.
+
+## Conversion Target
+
+For our purposes, we want to convert the compute-intensive `Toy` operations into
+a combination of operations from the `Affine` and `Standard` dialects for
+further optimization. To start off the lowering, we first define our conversion
+target:
+
+```c++
+void ToyToAffineLoweringPass::runOnFunction() {
+ // The first thing to define is the conversion target. This will define the
+ // final target for this lowering.
+ mlir::ConversionTarget target(getContext());
+
+ // We define the specific operations, or dialects, that are legal targets for
+ // this lowering. In our case, we are lowering to a combination of the
+ // `Affine` and `Standard` dialects.
+ target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
+
+ // We also define the Toy dialect as Illegal so that the conversion will fail
+ // if any of these operations are *not* converted. Given that we actually want
+ // a partial lowering, we explicitly mark the Toy operations that we don't
+ // want to lower, `toy.print`, as `legal`.
+ target.addIllegalDialect<ToyDialect>();
+ target.addLegalOp<PrintOp>();
+ ...
+} +``` + +## Conversion Patterns + +After the conversion target has been defined, we can define how to convert the +`illegal` operations into `legal` ones. Similarly to the canonicalization +framework introduced in [chapter 3](Ch-3.md), the +[`DialectConversion` framework](../../DialectConversion.md) also uses +[RewritePatterns](../../QuickstartRewrites.md) to perform the conversion logic. +These patterns may be the `RewritePatterns` seen before or a new type of pattern +specific to the conversion framework `ConversionPattern`. `ConversionPatterns` +are different from traditional `RewritePatterns` in that they accept an +additional `operands` parameter containing operands that have been +remapped/replaced. This is used when dealing with type conversions, as the +pattern will want to operate on values of the new type but match against the +old. For our lowering, this invariant will be useful as it translates from the +[TensorType](../../LangRef.md#tensor-type) currently being operated on to the +[MemRefType](../../LangRef.md#memref-type). Let's look at a snippet of lowering +the `toy.transpose` operation: + +```c++ +/// Lower the `toy.transpose` operation to an affine loop nest. +struct TransposeOpLowering : public mlir::ConversionPattern { + TransposeOpLowering(mlir::MLIRContext *ctx) + : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {} + + /// Match and rewrite the given `toy.transpose` operation, with the given + /// operands that have been remapped from `tensor<...>` to `memref<...>`. + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + // Call to a helper function that will lower the current operation to a set + // of affine loops. We provide a functor that operates on the remapped + // operands, as well as the loop induction variables for the inner most + // loop body. 
+ lowerOpToLoops( + op, operands, rewriter, + [loc](mlir::PatternRewriter &rewriter, + ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the TransposeOp. + // This allows for using the nice named accessors that are generated + // by the ODS. This adaptor is automatically provided by the ODS + // framework. + TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + mlir::Value input = transposeAdaptor.input(); + + // Transpose the elements by generating a load from the reverse + // indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return rewriter.create(loc, input, reverseIvs); + }); + return matchSuccess(); + } +}; +``` + +Now we can prepare the list of patterns to use during the lowering process: + +```c++ +void ToyToAffineLoweringPass::runOnFunction() { + ... + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + mlir::OwningRewritePatternList patterns; + patterns.insert<..., TransposeOpLowering>(&getContext()); + + ... +``` + +## Partial Lowering + +Once the patterns have been defined, we can perform the actual lowering. The +`DialectConversion` framework provides several different modes of lowering, but, +for our purposes, we will perform a partial lowering, as we will not convert +`toy.print` at this time. + +```c++ +void ToyToAffineLoweringPass::runOnFunction() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + mlir::ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine` and `Standard` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. 
Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. + target.addIllegalDialect(); + target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + mlir::OwningRewritePatternList patterns; + patterns.insert<..., TransposeOpLowering>(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + auto function = getFunction(); + if (mlir::failed(mlir::applyPartialConversion(function, target, patterns))) + signalPassFailure(); +} +``` + +### Design Considerations With Partial Lowering + +Before diving into the result of our lowering, this is a good time to discuss +potential design considerations when it comes to partial lowering. In our +lowering, we transform from a value-type, TensorType, to an allocated +(buffer-like) type, MemRefType. However, given that we do not lower the +`toy.print` operation, we need to temporarily bridge these two worlds. There are +many ways to go about this, each with their own tradeoffs: + +* Generate `load` operations from the buffer + +One option is to generate `load` operations from the buffer type to materialize +an instance of the value type. This allows for the definition of the `toy.print` +operation to remain unchanged. The downside to this approach is that the +optimizations on the `affine` dialect are limited, because the `load` will +actually involve a full copy that is only visible *after* our optimizations have +been performed. + +* Generate a new version of `toy.print` that operates on the lowered type + +Another option would be to have another, lowered, variant of `toy.print` that +operates on the lowered type. 
The benefit of this option is that there is no
+hidden, unnecessary copy to the optimizer. The downside is that another
+operation definition is needed that may duplicate many aspects of the first.
+Defining a base class in [ODS](../../OpDefinitions.md) may simplify this, but
+you still need to treat these operations separately.
+
+* Update `toy.print` to allow for operating on the lowered type
+
+A third option is to update the current definition of `toy.print` to allow for
+operating on the lowered type. The benefit of this approach is that it is
+simple, does not introduce an additional hidden copy, and does not require
+another operation definition. The downside to this option is that it requires
+mixing abstraction levels in the `Toy` dialect.
+
+For the sake of simplicity, we will use the third option for this lowering. This
+involves updating the type constraints on the PrintOp in the operation
+definition file:
+
+```tablegen
+def PrintOp : Toy_Op<"print"> {
+ ...
+
+ // The print operation takes an input tensor to print.
+ // We also allow a F64MemRef to enable interop during partial lowering.
+ let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); +} +``` + +## Complete Toy Example + +Looking back at our current working example: + +```mlir +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64> + %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> + "toy.print"(%3) : (tensor<3x2xf64>) -> () + "toy.return"() : () -> () +} +``` + +With affine lowering added to our pipeline, we can now generate: + +```mlir +func @main() { + %cst = constant 1.000000e+00 : f64 + %cst_0 = constant 2.000000e+00 : f64 + %cst_1 = constant 3.000000e+00 : f64 + %cst_2 = constant 4.000000e+00 : f64 + %cst_3 = constant 5.000000e+00 : f64 + %cst_4 = constant 6.000000e+00 : f64 + + // Allocating buffers for the inputs and outputs. + %0 = alloc() : memref<3x2xf64> + %1 = alloc() : memref<3x2xf64> + %2 = alloc() : memref<2x3xf64> + + // Initialize the input buffer with the constant values. + affine.store %cst, %2[0, 0] : memref<2x3xf64> + affine.store %cst_0, %2[0, 1] : memref<2x3xf64> + affine.store %cst_1, %2[0, 2] : memref<2x3xf64> + affine.store %cst_2, %2[1, 0] : memref<2x3xf64> + affine.store %cst_3, %2[1, 1] : memref<2x3xf64> + affine.store %cst_4, %2[1, 2] : memref<2x3xf64> + + // Load the transpose value from the input buffer and store it into the + // next input buffer. + affine.for %arg0 = 0 to 3 { + affine.for %arg1 = 0 to 2 { + %3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64> + affine.store %3, %1[%arg0, %arg1] : memref<3x2xf64> + } + } + + // Multiply and store into the output buffer. 
+ affine.for %arg0 = 0 to 2 { + affine.for %arg1 = 0 to 3 { + %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> + %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> + %5 = mulf %3, %4 : f64 + affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64> + } + } + + // Print the value held by the buffer. + "toy.print"(%0) : (memref<3x2xf64>) -> () + dealloc %2 : memref<2x3xf64> + dealloc %1 : memref<3x2xf64> + dealloc %0 : memref<3x2xf64> + return +} +``` + +## Taking Advantage of Affine Optimization + +Our naive lowering is correct, but it leaves a lot to be desired with regards to +efficiency. For example, the lowering of `toy.mul` has generated some redundant +loads. Let's look at how adding a few existing optimizations to the pipeline can +help clean this up. Adding the `LoopFusion` and `MemRefDataFlowOpt` passes to +the pipeline gives the following result: + +```mlir +func @main() { + %cst = constant 1.000000e+00 : f64 + %cst_0 = constant 2.000000e+00 : f64 + %cst_1 = constant 3.000000e+00 : f64 + %cst_2 = constant 4.000000e+00 : f64 + %cst_3 = constant 5.000000e+00 : f64 + %cst_4 = constant 6.000000e+00 : f64 + + // Allocating buffers for the inputs and outputs. + %0 = alloc() : memref<3x2xf64> + %1 = alloc() : memref<2x3xf64> + + // Initialize the input buffer with the constant values. + affine.store %cst, %1[0, 0] : memref<2x3xf64> + affine.store %cst_0, %1[0, 1] : memref<2x3xf64> + affine.store %cst_1, %1[0, 2] : memref<2x3xf64> + affine.store %cst_2, %1[1, 0] : memref<2x3xf64> + affine.store %cst_3, %1[1, 1] : memref<2x3xf64> + affine.store %cst_4, %1[1, 2] : memref<2x3xf64> + + affine.for %arg0 = 0 to 3 { + affine.for %arg1 = 0 to 2 { + // Load the transpose value from the input buffer. + %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64> + + // Multiply and store into the output buffer. + %3 = mulf %2, %2 : f64 + affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64> + } + } + + // Print the value held by the buffer. 
+ "toy.print"(%0) : (memref<3x2xf64>) -> () + dealloc %1 : memref<2x3xf64> + dealloc %0 : memref<3x2xf64> + return +} +``` + +Here, we can see that a redundant allocation was removed, the two loop nests +were fused, and some unnecessary `load`s were removed. You can build `toyc-ch5` +and try yourself: `toyc-ch5 test/lowering.toy -emit=mlir-affine`. We can also +check our optimizations by adding `-opt`. + +In this chapter we explored some aspects of partial lowering, with the intent to +optimize. In the [next chapter](Ch-6.md) we will continue the discussion about +dialect conversion by targeting LLVM for code generation. diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md new file mode 100644 index 0000000000000000000000000000000000000000..939b2b4f776476179d232f595f5cdad4381b4668 --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -0,0 +1,323 @@ +# Chapter 6: Lowering to LLVM and CodeGeneration + +[TOC] + +In the [previous chapter](Ch-5.md), we introduced the +[dialect conversion](../../DialectConversion.md) framework and partially lowered +many of the `Toy` operations to affine loop nests for optimization. In this +chapter, we will finally lower to LLVM for code generation. + +# Lowering to LLVM + +For this lowering, we will again use the dialect conversion framework to perform +the heavy lifting. However, this time, we will be performing a full conversion +to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already +lowered all but one of the `toy` operations, with the last being `toy.print`. +Before going over the conversion to LLVM, let's lower the `toy.print` operation. +We will lower this operation to a non-affine loop nest that invokes `printf` for +each element. Note that, because the dialect conversion framework supports +[transitive lowering](Glossary.md#transitive-lowering), we don't need to +directly emit operations in the LLVM dialect. 
By transitive lowering, we mean +that the conversion framework may apply multiple patterns to fully legalize an +operation. In this example, we are generating a structured loop nest instead of +the branch-form in the LLVM dialect. As long as we then have a lowering from the +loop operations to LLVM, the lowering will still succeed. + +During lowering we can get, or build, the declaration for printf as so: + +```c++ +/// Return a symbol reference to the printf function, inserting it into the +/// module if necessary. +static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { + auto *context = module.getContext(); + if (module.lookupSymbol("printf")) + return SymbolRefAttr::get("printf", context); + + // Create a function declaration for printf, the signature is: + // * `i32 (i8*, ...)` + auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, + /*isVarArg=*/true); + + // Insert the printf function into the body of the parent module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), "printf", llvmFnType); + return SymbolRefAttr::get("printf", context); +} +``` + +Now that the lowering for the printf operation has been defined, we can specify +the components necessary for the lowering. These are largely the same as the +components defined in the [previous chapter](Ch-5.md). + +## Conversion Target + +For this conversion, aside from the top-level module, we will be lowering +everything to the LLVM dialect. + +```c++ + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); +``` + +## Type Converter + +This lowering will also transform the MemRef types which are currently being +operated on into a representation in LLVM. 
To perform this conversion, we use a +TypeConverter as part of the lowering. This converter specifies how one type +maps to another. This is necessary now that we are performing more complicated +lowerings involving block arguments. Given that we don't have any +Toy-dialect-specific types that need to be lowered, the default converter is +enough for our use case. + +```c++ + LLVMTypeConverter typeConverter(&getContext()); +``` + +## Conversion Patterns + +Now that the conversion target has been defined, we need to provide the patterns +used for lowering. At this point in the compilation process, we have a +combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and +`affine` dialects already provide the set of patterns needed to transform them +into LLVM dialect. These patterns allow for lowering the IR in multiple stages +by relying on [transitive lowering](Glossary.md#transitive-lowering). + +```c++ + mlir::OwningRewritePatternList patterns; + mlir::populateAffineToStdConversionPatterns(patterns, &getContext()); + mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); + mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); + + // The only remaining operation to lower from the `toy` dialect, is the + // PrintOp. + patterns.insert(&getContext()); +``` + +## Full Lowering + +We want to completely lower to LLVM, so we use a `FullConversion`. This ensures +that only legal operations will remain after the conversion. 
+ +```c++ + mlir::ModuleOp module = getModule(); + if (mlir::failed(mlir::applyFullConversion(module, target, patterns, + &typeConverter))) + signalPassFailure(); +``` + +Looking back at our current working example: + +```mlir +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64> + %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64> + "toy.print"(%3) : (tensor<3x2xf64>) -> () + "toy.return"() : () -> () +} +``` + +We can now lower down to the LLVM dialect, which produces the following code: + +```mlir +llvm.func @free(!llvm<"i8*">) +llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32 +llvm.func @malloc(!llvm.i64) -> !llvm<"i8*"> +llvm.func @main() { + %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double + %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double + + ... + +^bb16: + %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> + %222 = llvm.mlir.constant(0 : index) : !llvm.i64 + %223 = llvm.mlir.constant(2 : index) : !llvm.i64 + %224 = llvm.mul %214, %223 : !llvm.i64 + %225 = llvm.add %222, %224 : !llvm.i64 + %226 = llvm.mlir.constant(1 : index) : !llvm.i64 + %227 = llvm.mul %219, %226 : !llvm.i64 + %228 = llvm.add %225, %227 : !llvm.i64 + %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*"> + %230 = llvm.load %229 : !llvm<"double*"> + %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32 + %232 = llvm.add %219, %218 : !llvm.i64 + llvm.br ^bb15(%232 : !llvm.i64) + + ... 
+ +^bb18: + %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> + %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*"> + llvm.call @free(%236) : (!llvm<"i8*">) -> () + %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> + %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*"> + llvm.call @free(%238) : (!llvm<"i8*">) -> () + %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> + %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*"> + llvm.call @free(%240) : (!llvm<"i8*">) -> () + llvm.return +} +``` + +See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for +more in-depth details on lowering to the LLVM dialect. + +# CodeGen: Getting Out of MLIR + +At this point we are right at the cusp of code generation. We can generate code +in the LLVM dialect, so now we just need to export to LLVM IR and setup a JIT to +run it. + +## Emitting LLVM IR + +Now that our module is comprised only of operations in the LLVM dialect, we can +export to LLVM IR. To do this programmatically, we can invoke the following +utility: + +```c++ + std::unique_ptr llvmModule = mlir::translateModuleToLLVMIR(module); + if (!llvmModule) + /* ... an error was encountered ... */ +``` + +Exporting our module to LLVM IR generates: + +```.llvm +define void @main() { + ... + +102: + %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0 + %104 = mul i64 %96, 2 + %105 = add i64 0, %104 + %106 = mul i64 %100, 1 + %107 = add i64 %105, %106 + %108 = getelementptr double, double* %103, i64 %107 + %109 = load double, double* %108 + %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109) + %111 = add i64 %100, 1 + br label %99 + + ... 
+ +115: + %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0 + %117 = bitcast double* %116 to i8* + call void @free(i8* %117) + %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0 + %119 = bitcast double* %118 to i8* + call void @free(i8* %119) + %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0 + %121 = bitcast double* %120 to i8* + call void @free(i8* %121) + ret void +} +``` + +If we enable optimization on the generated LLVM IR, we can trim this down quite +a bit: + +```.llvm +define void @main() + %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00) + %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01) + %putchar = tail call i32 @putchar(i32 10) + %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00) + %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01) + %putchar.1 = tail call i32 @putchar(i32 10) + %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00) + %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01) + %putchar.2 = tail call i32 @putchar(i32 10) + ret void +} + +``` + +The full code listing for dumping LLVM IR can be found in `Ch6/toy.cpp` in the +`dumpLLVMIR()` function: + +```c++ + +int dumpLLVMIR(mlir::ModuleOp module) { + // Translate the module, that contains the LLVM dialect, to LLVM IR. 
+ auto llvmModule = mlir::translateModuleToLLVMIR(module); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} +``` + +## Setting up a JIT + +Setting up a JIT to run the module containing the LLVM dialect can be done using +the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around +LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up +the JIT can be found in `Ch6/toy.cpp` in the `runJit()` function: + +```c++ +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. 
+ auto invocationResult = engine->invoke("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} +``` + +You can play around with it from the build directory: + +```sh +$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit +1.000000 2.000000 +3.000000 4.000000 +``` + +You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and +`-emit=llvm` to compare the various levels of IR involved. Also try options like +[`--print-ir-after-all`](../../WritingAPass.md#ir-printing) to track the +evolution of the IR throughout the pipeline. + +So far, we have worked with primitive data types. In the +[next chapter](Ch-7.md), we will add a composite `struct` type. diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md new file mode 100644 index 0000000000000000000000000000000000000000..6298e8253e9a5e350f7b28df478b8a788cc0cefc --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -0,0 +1,539 @@ +# Chapter 7: Adding a Composite Type to Toy + +[TOC] + +In the [previous chapter](Ch-6.md), we demonstrated an end-to-end compilation +flow from our Toy front-end to LLVM IR. In this chapter, we will extend the Toy +language to support a new composite `struct` type. + +## Defining a `struct` in Toy + +The first thing we need to define is the interface of this type in our `toy` +source language. The general syntax of a `struct` type in Toy is as follows: + +```toy +# A struct is defined by using the `struct` keyword followed by a name. +struct MyStruct { + # Inside of the struct is a list of variable declarations without initializers + # or shapes, which may also be other previously defined structs. + var a; + var b; +} +``` + +Structs may now be used in functions as variables or parameters by using the +name of the struct instead of `var`. The members of the struct are accessed via +a `.` access operator. 
Values of `struct` type may be initialized with a +composite initializer, or a comma-separated list of other initializers +surrounded by `{}`. An example is shown below: + +```toy +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. + return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. + var c = multiply_transpose(value); + print(c); +} +``` + +## Defining a `struct` in MLIR + +In MLIR, we will also need a representation for our struct types. MLIR does not +provide a type that does exactly what we need, so we will need to define our +own. We will simply define our `struct` as an unnamed container of a set of +element types. The name of the `struct` and its elements are only useful for the +AST of our `toy` compiler, so we don't need to encode it in the MLIR +representation. + +### Defining the Type Class + +#### Reserving a Range of Type Kinds + +Types in MLIR rely on having a unique `kind` value to ensure that casting checks +remain extremely efficient +([rationale](../../Rationale.md#reserving-dialect-type-kinds)). For `toy`, this +means we need to explicitly reserve a static range of type `kind` values in the +symbol registry file +[DialectSymbolRegistry](https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/DialectSymbolRegistry.def). + +```c++ +DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect +DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect + +// The following ranges are reserved for experimenting with MLIR dialects in a +// private context without having to register them here. 
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0) +``` + +These definitions will provide a range in the Type::Kind enum to use when +defining the derived types. + +```c++ +/// Create a local enumeration with all of the types that are defined by Toy. +namespace ToyTypes { +enum Types { + Struct = mlir::Type::FIRST_TOY_TYPE, +}; +} // end namespace ToyTypes +``` + +#### Defining the Type Class + +As mentioned in [chapter 2](Ch-2.md), [`Type`](../../LangRef.md#type-system) +objects in MLIR are value-typed and rely on having an internal storage object +that holds the actual data for the type. The `Type` class in itself acts as a +simple wrapper around an internal `TypeStorage` object that is uniqued within an +instance of an `MLIRContext`. When constructing a `Type`, we are internally just +constructing and uniquing an instance of a storage class. + +When defining a new `Type` that requires additional information beyond just the +`kind` (e.g. the `struct` type, which requires additional information to hold +the element types), we will need to provide a derived storage class. The +`primitive` types that don't have any additional data (e.g. the +[`index` type](../../LangRef.md#index-type)) don't require a storage class. + +##### Defining the Storage Class + +Type storage objects contain all of the data necessary to construct and unique a +type instance. Derived storage classes must inherit from the base +`mlir::TypeStorage` and provide a set of aliases and hooks that will be used by +the `MLIRContext` for uniquing. Below is the definition of the storage instance +for our `struct` type, with each of the necessary requirements detailed inline: + +```c++ +/// This class represents the internal storage of the Toy `StructType`. +struct StructTypeStorage : public mlir::TypeStorage { + /// The `KeyTy` is a required type that provides an interface for the storage + /// instance. This type will be used when uniquing an instance of the type + /// storage. 
For our struct type, we will unique each instance structurally on + /// the elements that it contains. + using KeyTy = llvm::ArrayRef; + + /// A constructor for the type storage instance. + StructTypeStorage(llvm::ArrayRef elementTypes) + : elementTypes(elementTypes) {} + + /// Define the comparison function for the key type with the current storage + /// instance. This is used when constructing a new instance to ensure that we + /// haven't already uniqued an instance of the given key. + bool operator==(const KeyTy &key) const { return key == elementTypes; } + + /// Define a hash function for the key type. This is used when uniquing + /// instances of the storage. + /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type + /// have hash functions available, so we could just omit this entirely. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Define a construction function for the key type from a set of parameters. + /// These parameters will be provided when constructing the storage instance + /// itself, see the `StructType::get` method further below. + /// Note: This method isn't necessary because KeyTy can be directly + /// constructed with the given parameters. + static KeyTy getKey(llvm::ArrayRef elementTypes) { + return KeyTy(elementTypes); + } + + /// Define a construction method for creating a new instance of this storage. + /// This method takes an instance of a storage allocator, and an instance of a + /// `KeyTy`. The given allocator must be used for *all* necessary dynamic + /// allocations used to create the type storage and its internal. + static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + // Copy the elements from the provided `KeyTy` into the allocator. + llvm::ArrayRef elementTypes = allocator.copyInto(key); + + // Allocate the storage instance and construct it. 
+ return new (allocator.allocate()) + StructTypeStorage(elementTypes); + } + + /// The following field contains the element types of the struct. + llvm::ArrayRef elementTypes; +}; +``` + +##### Defining the Type Class + +With the storage class defined, we can add the definition for the user-visible +`StructType` class. This is the class that we will actually interface with. + +```c++ +/// This class defines the Toy struct type. It represents a collection of +/// element types. All derived types in MLIR must inherit from the CRTP class +/// 'Type::TypeBase'. It takes as template parameters the concrete type +/// (StructType), the base class to use (Type), and the storage class +/// (StructTypeStorage). +class StructType : public mlir::Type::TypeBase { +public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + + /// This static method is used to support type inquiry through isa, cast, + /// and dyn_cast. + static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; } + + /// Create an instance of a `StructType` with the given element types. There + /// *must* be at least one element type. + static StructType get(llvm::ArrayRef elementTypes) { + assert(!elementTypes.empty() && "expected at least 1 element type"); + + // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance + // of this type. The first two parameters are the context to unique in and + // the kind of the type. The parameters after the type kind are forwarded to + // the storage instance. + mlir::MLIRContext *ctx = elementTypes.front().getContext(); + return Base::get(ctx, ToyTypes::Struct, elementTypes); + } + + /// Returns the element types of this struct type. + llvm::ArrayRef getElementTypes() { + // 'getImpl' returns a pointer to the internal storage instance. + return getImpl()->elementTypes; + } + + /// Returns the number of element type held by this struct. 
+ size_t getNumElementTypes() { return getElementTypes().size(); } +}; +``` + +We register this type in the `ToyDialect` constructor in a similar way to how we +did with operations: + +```c++ +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx) { + addTypes(); +} +``` + +With this we can now use our `StructType` when generating MLIR from Toy. See +examples/toy/Ch7/mlir/MLIRGen.cpp for more details. + +### Parsing and Printing + +At this point we can use our `StructType` during MLIR generation and +transformation, but we can't output or parse `.mlir`. For this we need to add +support for parsing and printing instances of the `StructType`. This can be done +by overriding the `parseType` and `printType` methods on the `ToyDialect`. + +```c++ +class ToyDialect : public mlir::Dialect { +public: + /// Parse an instance of a type registered to the toy dialect. + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + + /// Print an instance of a type registered to the toy dialect. + void printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const override; +}; +``` + +These methods take an instance of a high-level parser or printer that allows for +easily implementing the necessary functionality. Before going into the +implementation, let's think about the syntax that we want for the `struct` type +in the printed IR. As described in the +[MLIR language reference](../../LangRef.md#dialect-types), dialect types are +generally represented as: `! dialect-namespace < type-data >`, with a pretty +form available under certain circumstances. The responsibility of our `Toy` +parser and printer is to provide the `type-data` bits. We will define our +`StructType` as having the following form: + +``` + struct-type ::= `struct` `<` type (`,` type)* `>` +``` + +#### Parsing + +An implementation of the parser is shown below: + +```c++ +/// Parse an instance of a type registered to the toy dialect. 
+mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { + // Parse a struct type in the following form: + // struct-type ::= `struct` `<` type (`,` type)* `>` + + // NOTE: All MLIR parser function return a ParseResult. This is a + // specialization of LogicalResult that auto-converts to a `true` boolean + // value on failure to allow for chaining, but may be used with explicit + // `mlir::failed/mlir::succeeded` as desired. + + // Parse: `struct` `<` + if (parser.parseKeyword("struct") || parser.parseLess()) + return Type(); + + // Parse the element types of the struct. + SmallVector elementTypes; + do { + // Parse the current element type. + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + mlir::Type elementType; + if (parser.parseType(elementType)) + return nullptr; + + // Check that the type is either a TensorType or another StructType. + if (!elementType.isa() && + !elementType.isa()) { + parser.emitError(typeLoc, "element type for a struct must either " + "be a TensorType or a StructType, got: ") + << elementType; + return Type(); + } + elementTypes.push_back(elementType); + + // Parse the optional: `,` + } while (succeeded(parser.parseOptionalComma())); + + // Parse: `>` + if (parser.parseGreater()) + return Type(); + return StructType::get(elementTypes); +} +``` + +#### Printing + +An implementation of the printer is shown below: + +```c++ +/// Print an instance of a type registered to the toy dialect. +void ToyDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // Currently the only toy type is a struct type. + StructType structType = type.cast(); + + // Print the struct type according to the parser format. 
+
+  printer << "struct<";
+  mlir::interleaveComma(structType.getElementTypes(), printer);
+  printer << '>';
+}
+```
+
+Before moving on, let's look at a quick example showcasing the functionality
+we have now:
+
+```toy
+struct Struct {
+  var a;
+  var b;
+}
+
+def multiply_transpose(Struct value) {
+}
+```
+
+Which generates the following:
+
+```mlir
+module {
+  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+### Operating on `StructType`
+
+Now that the `struct` type has been defined, we can round-trip it through
+the IR. The next step is to add support for using it within our operations.
+
+#### Updating Existing Operations
+
+A few of our existing operations will need to be updated to handle `StructType`.
+The first step is to make the ODS framework aware of our Type so that we can use
+it in the operation definitions. A simple example is shown below:
+
+```tablegen
+// Provide a definition for the Toy StructType for use in ODS. This allows for
+// using StructType in a similar way to Tensor or MemRef.
+def Toy_StructType :
+    Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">;
+
+// Provide a definition of the types that are used within the Toy dialect.
+def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
+```
+
+We can then update our operations, e.g. `ReturnOp`, to also accept the
+`Toy_StructType`:
+
+```tablegen
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+  ...
+  let arguments = (ins Variadic<Toy_Type>:$input);
+  ...
+}
+```
+
+#### Adding New `Toy` Operations
+
+In addition to the existing operations, we will be adding a few new operations
+that will provide more specific handling of `structs`.
+
+##### `toy.struct_constant`
+
+This new operation materializes a constant value for a struct. In our current
+modeling, we just use an [array attribute](../../LangRef.md#array-attribute)
+that contains a set of constant values for each of the `struct` elements.
+
+```mlir
+  %0 = "toy.struct_constant"() {
+    value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>]
+  } : () -> !toy.struct<tensor<*xf64>>
+```
+
+##### `toy.struct_access`
+
+This new operation materializes the Nth element of a `struct` value.
+
+```mlir
+  %0 = "toy.struct_constant"() {
+    value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>]
+  } : () -> !toy.struct<tensor<*xf64>>
+  %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>>) -> tensor<*xf64>
+```
+
+With these operations, we can revisit our original example:
+
+```toy
+struct Struct {
+  var a;
+  var b;
+}
+
+# User defined generic functions may operate on struct types as well.
+def multiply_transpose(Struct value) {
+  # We can access the elements of a struct via the '.' operator.
+  return transpose(value.a) * transpose(value.b);
+}
+
+def main() {
+  # We initialize struct values using a composite initializer.
+  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};
+
+  # We pass these arguments to functions like we do with variables.
+ var c = multiply_transpose(value); + print(c); +} +``` + +and finally get a full MLIR module: + +```mlir +module { + func @multiply_transpose(%arg0: !toy.struct, tensor<*xf64>>) -> tensor<*xf64> { + %0 = "toy.struct_access"(%arg0) {index = 0 : i64} : (!toy.struct, tensor<*xf64>>) -> tensor<*xf64> + %1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64> + %2 = "toy.struct_access"(%arg0) {index = 1 : i64} : (!toy.struct, tensor<*xf64>>) -> tensor<*xf64> + %3 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64> + %4 = "toy.mul"(%1, %3) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "toy.return"(%4) : (tensor<*xf64>) -> () + } + func @main() { + %0 = "toy.struct_constant"() {value = [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct, tensor<*xf64>> + %1 = "toy.generic_call"(%0) {callee = @multiply_transpose} : (!toy.struct, tensor<*xf64>>) -> tensor<*xf64> + "toy.print"(%1) : (tensor<*xf64>) -> () + "toy.return"() : () -> () + } +} +``` + +#### Optimizing Operations on `StructType` + +Now that we have a few operations operating on `StructType`, we also have many +new constant folding opportunities. 
+ +After inlining, the MLIR module in the previous section looks something like: + +```mlir +module { + func @main() { + %0 = "toy.struct_constant"() {value = [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct, tensor<*xf64>> + %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct, tensor<*xf64>>) -> tensor<*xf64> + %2 = "toy.transpose"(%1) : (tensor<*xf64>) -> tensor<*xf64> + %3 = "toy.struct_access"(%0) {index = 1 : i64} : (!toy.struct, tensor<*xf64>>) -> tensor<*xf64> + %4 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64> + %5 = "toy.mul"(%2, %4) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "toy.print"(%5) : (tensor<*xf64>) -> () + "toy.return"() : () -> () + } +} +``` + +We have several `toy.struct_access` operations that access into a +`toy.struct_constant`. As detailed in [chapter 3](Ch-3.md), we can add folders +for these `toy` operations by setting the `hasFolder` bit on the operation +definition and providing a definition of the `*Op::fold` method. + +```c++ +/// Fold constants. +OpFoldResult ConstantOp::fold(ArrayRef operands) { return value(); } + +/// Fold struct constants. +OpFoldResult StructConstantOp::fold(ArrayRef operands) { + return value(); +} + +/// Fold simple struct access operations that access into a constant. +OpFoldResult StructAccessOp::fold(ArrayRef operands) { + auto structAttr = operands.front().dyn_cast_or_null(); + if (!structAttr) + return nullptr; + + size_t elementIndex = index().getZExtValue(); + return structAttr.getValue()[elementIndex]; +} +``` + +To ensure that MLIR generates the proper constant operations when folding our +`Toy` operations, i.e. `ConstantOp` for `TensorType` and `StructConstant` for +`StructType`, we will need to provide an override for the dialect hook +`materializeConstant`. 
This allows for generic MLIR operations to create
+constants for the `Toy` dialect when necessary.
+
+```c++
+mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                 mlir::Attribute value,
+                                                 mlir::Type type,
+                                                 mlir::Location loc) {
+  if (type.isa<StructType>())
+    return builder.create<StructConstantOp>(loc, type,
+                                            value.cast<mlir::ArrayAttr>());
+  return builder.create<ConstantOp>(loc, type,
+                                    value.cast<mlir::DenseElementsAttr>());
+}
+```
+
+With this, we can now generate code that can be lowered to LLVM without any
+changes to our pipeline.
+
+```mlir
+module {
+  func @main() {
+    %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+    %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+    %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+    "toy.print"(%2) : (tensor<3x2xf64>) -> ()
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+You can build `toyc-ch7` and try it yourself: `toyc-ch7
+test/Examples/Toy/Ch7/struct-codegen.toy -emit=mlir`. More details on defining
+custom types can be found in
+[DefiningAttributesAndTypes](../../DefiningAttributesAndTypes.md).
diff --git a/mlir/docs/UsageOfConst.md b/mlir/docs/UsageOfConst.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e8ce78e960c256032bc508790c8e19b2123ed9b
--- /dev/null
+++ b/mlir/docs/UsageOfConst.md
@@ -0,0 +1,272 @@
+# Usage of 'Const' in MLIR, for core IR types
+
+aka, where'd `const` go?
+
+The MLIR data structures that represent the IR itself (Instruction, Block, etc)
+form a graph-based data structure, and the compiler analyses and passes
+frequently walk this graph (e.g. traversing from defs to users). The early
+design of MLIR adopted the `const` model of LLVM, which is familiar and well
+understood (even though the LLVM implementation is flawed in many ways).
+
+The design team since decided to change to a different model, which eschews
+`const` entirely for the core IR types: you should never see a `const` method on
+`Operation`, should never see the type `const Value`, and you shouldn't feel bad
+about this. That said, you *should* use `const` for non-IR types, like
+`SmallVector`'s and many other things.
+
+The document below explains this design point from the viewpoint of "why make a
+change", to explain the rationale and the tradeoffs involved that led us to this
+potentially controversial design point.
+
+Bjarke Roune summarized the situation like this:
+
+> In my opinion `const` correctness is highly valuable, catching many bugs and
+> making it clear in a code base where the mutations happen. In my opinion
+> `const` correctness still isn't worth it in particular for IR elements because
+> of the special uses and properties of IRs, in particular that it is common to
+> transfer a pointer/reference to an instruction from an analysis to an
+> optimization which will change the instruction. The analysis should be const,
+> the optimization needs to get a non-`const` pointer. So all analyses either
+> end up being templates (and if they never get instantiated in a const context,
+> then the point of `const` correctness has been defeated), you need to somehow
+> launder the const in a safe way or there will be `const_cast`s. These options
+> are all bad, probably so bad as to out-weigh the benefits of const.
+
+# Reconsidering `const` in MLIR
+
+This document argues this design is introducing significant sub-optimalities
+into the MLIR codebase, argues that the cost/benefit tradeoff of this design is
+a poor tradeoff, and proposes switching to a much simpler approach - eliminating
+the use of const on these IR types entirely.
+
+**Note:** **This document is only discussing things like `const Value` and
+`const Operation*`. There is no proposed change for other types, e.g.
+`SmallVector` references, the immutable types like `Attribute`, etc.** + +## Background: The LLVM Const Model + +The LLVM and MLIR data structures provide the IR data structures (like +`mlir::Operation`s and their users) as a structured cyclic graph data structure. +Clients of the IR typically walk up and down the graph, perform dynamic down +casting (of various sorts) to check for patterns, and use some high-abstraction +pattern matching and binding facilities to do their work. + +The basic idea of LLVM's design is that these traversals of the IR should +preserve the const'ness of a pointer: if you have a const pointer to an +instruction and ask for its parent (or operand, users, etc), you should get a +const pointer to the block containing the instruction (or value defining the +operand, instruction using the instruction, etc). The instruction class looks +like this: + +``` +namespace llvm { +class Instruction : ... { + BasicBlock *Parent; +public: + // A const instruction returns a const parent pointer. + inline const BasicBlock *getParent() const { return Parent; } + // A non-const instruction returns a non-const parent pointer. + inline BasicBlock *getParent() { return Parent; } +… +}; +} +``` + +The rationale for this design is that it would be const-incorrect to return a +non-const pointer from getParent, because you could then walk the block to find +the instruction again and get non-const references to the same instruction - all +without a `const_cast`. + +This const model is simple and the C++ type system generally supports it through +code duplication of methods. That said, LLVM is actually inconsistent and buggy +about this. Even the core classes have bugs: `llvm::Instruction::getOperand()` +isn't currently const correct! There are other subsystems (e.g. the +`llvm/IR/PatternMatch.h` APIs) where you can perform a pattern match on a const +IR object and bind a non-const IR object. + +LLVM is a mature technology with hundreds of people working on it. 
The fact that +it still isn't correctly following the const model it set out for strongly hints +that one of: 1) The design is too complicated to be practical, 2) the benefits +of the model aren't worth the cost of the complexity, or 3) both 1 and 2, +together in some combination. + +## Advantages of Const-correctness in MLIR + +Even though this doc argues for eliminating const from MLIR, it is important to +evaluate that as a tradeoff with the advantages the const model provides, +allowing us to do a cost/benefit tradeoff. These are the benefits we see: + +The major advantage of allowing const on MLIR types is as a marker in APIs that +indicate that the function will not modify the specified values. For example, +the dominator APIs have a `dominates(const Block*, const Block*)` method, and +the consts provide a way of indicating that the call won't modify the blocks +passed in - similarly predicates like `Instruction::isTerminator() const` do not +modify the receiver object. + +It is also an advantage that MLIR follows the generally prevailing pattern of +C++ code, which generally uses const. Consistency with the community norm is +important. + +## Costs of Const-correctness in MLIR + +As mentioned above, early work on MLIR adopted the same design as LLVM intended, +allowing const-correct traversals in the APIs. Here we discuss the various costs +of doing this by looking at some examples, listed in roughly increasing order of +severity. + +### Pervasively duplicated accessors + +Just as the getParent() example above shows, achieving this const model requires +that all of the graph traversal accessors be duplicated into const and non-const +versions. This causes API bloat and slows compile time, but these are minor +problems. 
+ +The more significant issue is that this duplication can be so significant that +the signal disappears in the noise, for example `mlir::Operation` ends up with +things like this, which is twice as much API surface area just to try to satisfy +const. + +```c++ + operand_iterator operand_begin(); + operand_iterator operand_end(); + + /// Returns an iterator on the underlying Value's (Value ). + operand_range getOperands(); + + // Support const operand iteration. + using const_operand_iterator = + OperandIterator; + using const_operand_range = llvm::iterator_range; + + const_operand_iterator operand_begin() const; + const_operand_iterator operand_end() const; + + /// Returns a const iterator on the underlying Value's (Value ). + llvm::iterator_range getOperands() const; + + ArrayRef getOpOperands() const { + return getOperandStorage().getOperands(); + } + MutableArrayRef getOpOperands() { + return getOperandStorage().getOperands(); + } + + OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; } + const OpOperand &getOpOperand(unsigned idx) const { + return getOpOperands()[idx]; + } + +``` + +### Templated accessors + +A related issue is that having to provide both const and non-const versions of +accessors leads to us having to turn more code into templates than would +otherwise be desirable. Things like `ResultIterator` and `ResultTypeIterator` +are templates *_only_* because they are generic over const and non-const +versions of types. This leads to them being defined inline in headers (instead +of in .cpp files). + +Thus, our const model is leading to more code in headers and more complexity in +the implementation. + +### Const incorrect in practice + +For some things, const is more trouble than it is worth, so they never get +updated. + +This means that certain API in practice don't provide a const variant, leading +to pervasive use of `const_cast` to drop the const qualifier. 
For example the +logic in `Matchers.h` doesn't support const pointers at all (b/123355851), even +though matching and binding values themselves makes perfect sense for both const +and non-const values. Actually fixing this would cause massive code bloat and +complexity. + +Other parts of the code are just outright incorrect. For example, the operation +cloning methods are defined on Operation like this: + +```C++ +Operation *clone(BlockAndValueMapping &mapper, MLIRContext *context) const; + +Operation *clone(MLIRContext *context) const; +``` + +While it makes sense for a clone method to be `const` conceptually (the original +operation isn't modified) this is a violation of the model, since the returned +operation must be mutable, and provides access to the full graph of operands as +the original operation, violating the graph based const model we were shooting +for. + +### The `OpPointer` and `ConstOpPointer` Classes + +The "typed operation" classes for registered operations (e.g. like `DimOp` for +the "std.dim" operation in standard ops) contain a pointer to an operation and +provide typed APIs for processing it. + +However, this is a problem for our current `const` design - `const DimOp` means +the pointer itself is immutable, not the pointee. The current solution for this +is the `OpPointer<>` and `ConstOpPointer<>` classes, which exist solely to +provide const correctness when referring to a typed operation. Instead of +referring to `DimOp` directly, we need to use `OpPointer` and +`ConstOpPointer` to preserve this constness. + +While `auto` hides many instances of these `OpPointer` classes, their presence +leads to extremely ugly APIs. It also obscures the fact that the user does not +have a direct `DimOp` object, creating easy pitfalls with subtly incorrect +semantics: + +```C++ +// OpPointer encodes unnecessary and superfluous information into the API. 
SmallVector<OpPointer<AffineForOp>, 8> stripmineSink(
    OpPointer<AffineForOp> forOp, uint64_t factor,
    ArrayRef<OpPointer<AffineForOp>> targets);
// Compared to the much cleaner and easier to read...
SmallVector<AffineForOp, 8> stripmineSink(AffineForOp forOp, uint64_t factor,
                                          ArrayRef<AffineForOp> targets);

// OpPointer is easy to misuse.
if (auto *dimOp = inst->dyn_cast<DimOp>()) {
  // This is actually undefined behavior because dyn_cast actually returns
  // OpPointer<DimOp>. OpPointer<DimOp> happily implicitly converts to DimOp *
  // creating undefined behavior that will execute correctly most of the time.
}
```

It would be much better to eliminate them entirely, and just pass around `DimOp`
directly. For example, instead of:

```C++
LogicalResult mlir::getIndexSet(MutableArrayRef<OpPointer<AffineForOp>> forOps,
                                FlatAffineConstraints *domain) {

```

It would be a lot nicer to just have:

```c++
LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
                                FlatAffineConstraints *domain) {
Types like `OpPointer` and `ConstOpPointer` that exist solely to propagate + const can be entirely removed from the codebase. +1. We can close bugs complaining about const incorrectness in the IR. diff --git a/mlir/docs/WritingAPass.md b/mlir/docs/WritingAPass.md new file mode 100644 index 0000000000000000000000000000000000000000..5119c469e20819a40d6fd41c33435d6f37ce5b7e --- /dev/null +++ b/mlir/docs/WritingAPass.md @@ -0,0 +1,835 @@ +# Writing a Pass + +[TOC] + +Passes represent the basic infrastructure for transformation and optimization. +This document provides a quickstart to the pass infrastructure in MLIR and how +to use it. + +See [MLIR specification](LangRef.md) for more information about MLIR and its +core aspects, such as the IR structure and operations. + +See [MLIR Rewrites](QuickstartRewrites.md) for a quick start on graph rewriting +in MLIR. If your transformation involves pattern matching operation DAGs, this +is a great place to start. + +## Operation Pass + +In MLIR, the main unit of abstraction and transformation is an +[operation](LangRef.md#operations). As such, the pass manager is designed to +work on instances of operations at different levels of nesting. The structure of +the [pass manager](#pass-manager), and the concept of nesting, is detailed +further below. All passes in MLIR derive from `OperationPass` and adhere to the +following restrictions; any noncompliance will lead to problematic behavior in +multithreaded and other advanced scenarios: + +* Modify anything within the parent block/region/operation/etc, outside of the + current operation being operated on. This includes adding or removing + operations from the parent block. +* Maintain pass state across invocations of `runOnOperation`. A pass may be + run on several different operations with no guarantee of execution order. + * When multithreading, a specific pass instance may not even execute on + all operations within the module. 
As such, a pass should not rely on + running on all operations. +* Modify the state of another operation not nested within the current + operation being operated on. + * Other threads may be operating on different operations within the module + simultaneously. +* Maintain any global mutable state, e.g. static variables within the source + file. All mutable state should be maintained by an instance of the pass. +* Must be copy-constructible, multiple instances of the pass may be created by + the pass manager to process operations in parallel. +* Inspect the IR of sibling operations. Other threads may be modifying these + operations in parallel. + +When creating an operation pass, there are two different types to choose from +depending on the usage scenario: + +### OperationPass : Op-Specific + +An `op-specific` operation pass operates explicitly on a given operation type. +This operation type must adhere to the restrictions set by the pass manager for +pass execution. + +To define an op-specific operation pass, a derived class must adhere to the +following: + +* Inherit from the CRTP class `OperationPass` and provide the operation type + as an additional template parameter. +* Override the virtual `void runOnOperation()` method. + +A simple pass may look like: + +```c++ +namespace { +struct MyFunctionPass : public OperationPass { + void runOnOperation() override { + // Get the current FuncOp operation being operated on. + FuncOp f = getOperation(); + + // Walk the operations within the function. + f.walk([](Operation *inst) { + .... + }); + } +}; +} // end anonymous namespace + +// Register this pass to make it accessible to utilities like mlir-opt. +// (Pass registration is discussed more below) +static PassRegistration pass( + "flag-name-to-invoke-pass-via-mlir-opt", "Pass description here"); +``` + +### OperationPass : Op-Agnostic + +An `op-agnostic` pass operates on the operation type of the pass manager that it +is added to. 
This means that a pass that operates on several different operation +types in the same way only needs one implementation. + +To create an operation pass, a derived class must adhere to the following: + +* Inherit from the CRTP class `OperationPass`. +* Override the virtual `void runOnOperation()` method. + +A simple pass may look like: + +```c++ +struct MyOperationPass : public OperationPass { + void runOnOperation() override { + // Get the current operation being operated on. + Operation *op = getOperation(); + ... + } +}; +``` + +## Analysis Management + +An important concept, along with transformation passes, are analyses. These are +conceptually similar to transformation passes, except that they compute +information on a specific operation without modifying it. In MLIR, analyses are +not passes but free-standing classes that are computed lazily on-demand and +cached to avoid unnecessary recomputation. An analysis in MLIR must adhere to +the following: + +* Provide a valid constructor taking an `Operation*`. +* Must not modify the given operation. + +An analysis may provide additional hooks to control various behavior: + +* `bool isInvalidated(const AnalysisManager::PreservedAnalyses &)` + +Given a preserved analysis set, the analysis returns true if it should truly be +invalidated. This allows for more fine-tuned invalidation in cases where an +analysis wasn't explicitly marked preserved, but may be preserved (or +invalidated) based upon other properties such as analyses sets. + +### Querying Analyses + +The base `OperationPass` class provide utilities for querying and preserving +analyses for the current operation being processed. + +* OperationPass automatically provides the following utilities for querying + analyses: + * `getAnalysis<>` + - Get an analysis for the current operation, constructing it if + necessary. + * `getCachedAnalysis<>` + - Get an analysis for the current operation, if it already exists. 
+ * `getCachedParentAnalysis<>` + - Get an analysis for a given parent operation, if it exists. + * `getCachedChildAnalysis<>` + - Get an analysis for a given child operation, if it exists. + * `getChildAnalysis<>` + - Get an analysis for a given child operation, constructing it if + necessary. + +Using the example passes defined above, let's see some examples: + +```c++ +/// An interesting analysis. +struct MyOperationAnalysis { + // Compute this analysis with the provided operation. + MyOperationAnalysis(Operation *op); +}; + +void MyOperationPass::runOnOperation() { + // Query MyOperationAnalysis for the current operation. + MyOperationAnalysis &myAnalysis = getAnalysis(); + + // Query a cached instance of MyOperationAnalysis for the current operation. + // It will not be computed if it doesn't exist. + auto optionalAnalysis = getCachedAnalysis(); + if (optionalAnalysis) + ... + + // Query a cached instance of MyOperationAnalysis for the parent operation of + // the current operation. It will not be computed if it doesn't exist. + auto optionalAnalysis = getCachedParentAnalysis(); + if (optionalAnalysis) + ... +} +``` + +### Preserving Analyses + +Analyses that are constructed after being queried by a pass are cached to avoid +unnecessary computation if they are requested again later. To avoid stale +analyses, all analyses are assumed to be invalidated by a pass. To avoid +invalidation, a pass must specifically mark analyses that are known to be +preserved. + +* All Pass classes automatically provide the following utilities for + preserving analyses: + * `markAllAnalysesPreserved` + * `markAnalysesPreserved<>` + +```c++ +void MyOperationPass::runOnOperation() { + // Mark all analyses as preserved. This is useful if a pass can guarantee + // that no transformation was performed. + markAllAnalysesPreserved(); + + // Mark specific analyses as preserved. 
This is used if some transformation + // was performed, but some analyses were either unaffected or explicitly + // preserved. + markAnalysesPreserved(); +} +``` + +## Pass Failure + +Passes in MLIR are allowed to gracefully fail. This may happen if some invariant +of the pass was broken, potentially leaving the IR in some invalid state. If +such a situation occurs, the pass can directly signal a failure to the pass +manager. If a pass signaled a failure when executing, no other passes in the +pipeline will execute and the `PassManager::run` will return failure. Failure +signaling is provided in the form of a `signalPassFailure` method. + +```c++ +void MyPass::runOnOperation() { + // Signal failure on a broken invariant. + if (some_broken_invariant) { + signalPassFailure(); + return; + } +} +``` + +## Pass Manager + +Above we introduced the different types of passes and their constraints. Now +that we have our pass we need to be able to run it over a specific module. This +is where the pass manager comes into play. The `PassManager` class is used to +configure and run a pipeline. The `OpPassManager` class is used to schedule +passes to run at a specific level of nesting. + +### OpPassManager + +An `OpPassManager` is essentially a collection of passes to execute on an +operation of a given type. This operation type must adhere to the following +requirement: + +* Must be registered and marked `IsolatedFromAbove`. + + * Passes are expected to not modify operations at or above the current + operation being processed. If the operation is not isolated, it may + inadvertently modify the use-list of an operation it is not supposed to + modify. + +Passes can be added to a pass manager via `addPass`. The pass must either be an +`op-specific` pass operating on the same operation type as `OpPassManager`, or +an `op-agnostic` pass. + +An `OpPassManager` cannot be created directly, but must be explicitly nested +within another `OpPassManager` via the `nest<>` method. 
This method takes the +operation type that the nested pass manager will operate on. At the top-level, a +`PassManager` acts as an `OpPassManager` that operates on the +[`module`](LangRef.md#module) operation. Nesting in this sense, corresponds to +the structural nesting within [Regions](LangRef.md#regions) of the IR. + +For example, the following `.mlir`: + +``` +module { + spv.module "Logical" "GLSL450" { + func @foo() { + ... + } + } +} +``` + +Has the nesting structure of: + +``` +`module` + `spv.module` + `function` +``` + +Below is an example of constructing a pipeline that operates on the above +structure: + +```c++ +PassManager pm(ctx); + +// Add a pass on the top-level module operation. +pm.addPass(std::make_unique()); + +// Nest a pass manager that operates on spirv module operations nested directly +// under the top-level module. +OpPassManager &nestedModulePM = pm.nest(); +nestedModulePM.addPass(std::make_unique()); + +// Nest a pass manager that operates on functions within the nested SPIRV +// module. +OpPassManager &nestedFunctionPM = nestedModulePM.nest(); +nestedFunctionPM.addPass(std::make_unique()); + +// Run the pass manager on the top-level module. +Module m = ...; +if (failed(pm.run(m))) + ... // One of the passes signaled a failure. +``` + +The above pass manager would contain the following pipeline structure: + +``` +OpPassManager + MyModulePass + OpPassManager + MySPIRVModulePass + OpPassManager + MyFunctionPass +``` + +These pipelines are then run over a single operation at a time. This means that, +for example, given a series of consecutive passes on FuncOp, it will execute all +on the first function, then all on the second function, etc. until the entire +program has been run through the passes. This provides several benefits: + +* This improves the cache behavior of the compiler, because it is only + touching a single function at a time, instead of traversing the entire + program. 
+* This improves multi-threading performance by reducing the number of jobs + that need to be scheduled, as well as increasing the efficiency of each job. + An entire function pipeline can be run on each function asynchronously. + +## Pass Registration + +Briefly shown in the example definitions of the various pass types is the +`PassRegistration` class. This is a utility to register derived pass classes so +that they may be created, and inspected, by utilities like mlir-opt. Registering +a pass class takes the form: + +```c++ +static PassRegistration pass("command-line-arg", "description"); +``` + +* `MyPass` is the name of the derived pass class. +* "command-line-arg" is the argument to use on the command line to invoke the + pass from `mlir-opt`. +* "description" is a description of the pass. + +For passes that cannot be default-constructed, `PassRegistration` accepts an +optional third argument that takes a callback to create the pass: + +```c++ +static PassRegistration pass( + "command-line-arg", "description", + []() -> std::unique_ptr { + std::unique_ptr p = std::make_unique(/*options*/); + /*... non-trivial-logic to configure the pass ...*/; + return p; + }); +``` + +This variant of registration can be used, for example, to accept the +configuration of a pass from command-line arguments and pass it over to the pass +constructor. Make sure that the pass is copy-constructible in a way that does +not share data as the [pass manager](#pass-manager) may create copies of the +pass to run in parallel. + +### Pass Pipeline Registration + +Described above is the mechanism used for registering a specific derived pass +class. On top of that, MLIR allows for registering custom pass pipelines in a +similar fashion. This allows for custom pipelines to be available to tools like +mlir-opt in the same way that passes are, which is useful for encapsulating +common pipelines like the "-O1" series of passes. 
Pipelines are registered via a +similar mechanism to passes in the form of `PassPipelineRegistration`. Compared +to `PassRegistration`, this class takes an additional parameter in the form of a +pipeline builder that modifies a provided `OpPassManager`. + +```c++ +void pipelineBuilder(OpPassManager &pm) { + pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); +} + +// Register an existing pipeline builder function. +static PassPipelineRegistration<> pipeline( + "command-line-arg", "description", pipelineBuilder); + +// Register an inline pipeline builder. +static PassPipelineRegistration<> pipeline( + "command-line-arg", "description", [](OpPassManager &pm) { + pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); + }); +``` + +Pipeline registration also allows for simplified registration of +specifializations for existing passes: + +```c++ +static PassPipelineRegistration<> foo10( + "foo-10", "Foo Pass 10", [] { return std::make_unique(10); } ); +``` + +### Textual Pass Pipeline Specification + +In the previous sections, we showed how to register passes and pass pipelines +with a specific argument and description. Once registered, these can be used on +the command line to configure a pass manager. The limitation of using these +arguments directly is that they cannot build a nested pipeline. For example, if +our module has another module nested underneath, with just `-my-module-pass` +there is no way to specify that this pass should run on the nested module and +not the top-level module. This is due to the flattened nature of the command +line. + +To circumvent this limitation, MLIR also supports a textual description of a +pass pipeline. This allows for explicitly specifying the structure of the +pipeline to add to the pass manager. This includes the nesting structure, as +well as the passes and pass pipelines to run. 
A textual pipeline is defined as a +series of names, each of which may in itself recursively contain a nested +pipeline description. The syntax for this specification is as follows: + +```ebnf +pipeline ::= op-name `(` pipeline-element (`,` pipeline-element)* `)` +pipeline-element ::= pipeline | (pass-name | pass-pipeline-name) options? +options ::= '{' (key ('=' value)?)+ '}' +``` + +* `op-name` + * This corresponds to the mnemonic name of an operation to run passes on, + e.g. `func` or `module`. +* `pass-name` | `pass-pipeline-name` + * This corresponds to the command-line argument of a registered pass or + pass pipeline, e.g. `cse` or `canonicalize`. +* `options` + * Options are pass specific key value pairs that are handled as described + in the [instance specific pass options](#instance-specific-pass-options) + section. + +For example, the following pipeline: + +```shell +$ mlir-opt foo.mlir -cse -canonicalize -convert-std-to-llvm +``` + +Can also be specified as (via the `-pass-pipeline` flag): + +```shell +$ mlir-opt foo.mlir -pass-pipeline='func(cse, canonicalize), convert-std-to-llvm' +``` + +In order to support round-tripping your pass to the textual representation using +`OpPassManager::printAsTextualPipeline(raw_ostream&)`, override +`Pass::printAsTextualPipeline(raw_ostream&)` to format your pass-name and +options in the format described above. + +### Instance Specific Pass Options + +Options may be specified for a parametric pass. Individual options are defined +using the [LLVM command line](https://llvm.org/docs/CommandLine.html) flag +definition rules. These options will then be parsed at pass construction time +independently for each instance of the pass. To provide options for passes, the +`Option<>` and `OptionList<>` classes may be used: + +```c++ +struct MyPass ... { + /// Make sure that we have a valid default constructor and copy constructor to + /// make sure that the options are initialized properly. 
  MyPass() = default;
  MyPass(const MyPass& pass) {}

  // These just forward onto llvm::cl::list and llvm::cl::opt respectively.
  Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
  ListOption<int> exampleListOption{*this, "list-flag-name",
                                    llvm::cl::desc("...")};
This
class takes as constructor arguments: the parent pass, a name, and a
description.
+===-------------------------------------------------------------------------=== +MyPass + (S) 21 testStat - A test statistic +``` + +## Pass Instrumentation + +MLIR provides a customizable framework to instrument pass execution and analysis +computation. This is provided via the `PassInstrumentation` class. This class +provides hooks into the PassManager that observe various pass events: + +* `runBeforePipeline` + * This callback is run just before a pass pipeline, i.e. pass manager, is + executed. +* `runAfterPipeline` + * This callback is run right after a pass pipeline has been executed, + successfully or not. +* `runBeforePass` + * This callback is run just before a pass is executed. +* `runAfterPass` + * This callback is run right after a pass has been successfully executed. + If this hook is executed, runAfterPassFailed will not be. +* `runAfterPassFailed` + * This callback is run right after a pass execution fails. If this hook is + executed, runAfterPass will not be. +* `runBeforeAnalysis` + * This callback is run just before an analysis is computed. +* `runAfterAnalysis` + * This callback is run right after an analysis is computed. + +PassInstrumentation objects can be registered directly with a +[PassManager](#pass-manager) instance via the `addInstrumentation` method. +Instrumentations added to the PassManager are run in a stack like fashion, i.e. +the last instrumentation to execute a `runBefore*` hook will be the first to +execute the respective `runAfter*` hook. 
Below is an example instrumentation
that counts the number of times DominanceInfo is computed:

```c++
struct DominanceCounterInstrumentation : public PassInstrumentation {
  unsigned &count;

  DominanceCounterInstrumentation(unsigned &count) : count(count) {}
  void runAfterAnalysis(llvm::StringRef, AnalysisID *id, Operation *) override {
    if (id == AnalysisID::getID<DominanceInfo>())
      ++count;
  }
};

MLIRContext *ctx = ...;
PassManager pm(ctx);

// Add the instrumentation to the pass manager.
unsigned domInfoCount;
pm.addInstrumentation(
    std::make_unique<DominanceCounterInstrumentation>(domInfoCount));

// Run the pass manager on a module operation.
ModuleOp m = ...;
if (failed(pm.run(m)))
  ...

llvm::errs() << "DominanceInfo was computed " << domInfoCount << " times!\n";
```
+ +```shell +$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list + +===-------------------------------------------------------------------------=== + ... Pass execution timing report ... +===-------------------------------------------------------------------------=== + Total Execution Time: 0.0203 seconds + + ---Wall Time--- --- Name --- + 0.0047 ( 55.9%) Canonicalizer + 0.0019 ( 22.2%) VerifierPass + 0.0016 ( 18.5%) LLVMLoweringPass + 0.0003 ( 3.4%) CSE + 0.0002 ( 1.9%) (A) DominanceInfo + 0.0084 (100.0%) Total +``` + +##### Pipeline Display Mode + +In this mode, the results are displayed in a nested pipeline view that mirrors +the internal pass pipeline that is being executed in the pass manager. This view +is useful for understanding specifically which parts of the pipeline are taking +the most time, and can also be used to identify when analyses are being +invalidated and recomputed. This is the default display mode. + +```shell +$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing + +===-------------------------------------------------------------------------=== + ... Pass execution timing report ... +===-------------------------------------------------------------------------=== + Total Execution Time: 0.0249 seconds + + ---Wall Time--- --- Name --- + 0.0058 ( 70.8%) 'func' Pipeline + 0.0004 ( 4.3%) CSE + 0.0002 ( 2.6%) (A) DominanceInfo + 0.0004 ( 4.8%) VerifierPass + 0.0046 ( 55.4%) Canonicalizer + 0.0005 ( 6.2%) VerifierPass + 0.0005 ( 5.8%) VerifierPass + 0.0014 ( 17.2%) LLVMLoweringPass + 0.0005 ( 6.2%) VerifierPass + 0.0082 (100.0%) Total +``` + +##### Multi-threaded Pass Timing + +When multi-threading is enabled in the pass manager the meaning of the display +slightly changes. First, a new timing column is added, `User Time`, that +displays the total time spent across all threads. 
Secondly, the `Wall Time` +column displays the longest individual time spent amongst all of the threads. +This means that the `Wall Time` column will continue to give an indicator on the +perceived time, or clock time, whereas the `User Time` will display the total +cpu time. + +```shell +$ mlir-opt foo.mlir -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing + +===-------------------------------------------------------------------------=== + ... Pass execution timing report ... +===-------------------------------------------------------------------------=== + Total Execution Time: 0.0078 seconds + + ---User Time--- ---Wall Time--- --- Name --- + 0.0177 ( 88.5%) 0.0057 ( 71.3%) 'func' Pipeline + 0.0044 ( 22.0%) 0.0015 ( 18.9%) CSE + 0.0029 ( 14.5%) 0.0012 ( 15.2%) (A) DominanceInfo + 0.0038 ( 18.9%) 0.0015 ( 18.7%) VerifierPass + 0.0089 ( 44.6%) 0.0025 ( 31.1%) Canonicalizer + 0.0006 ( 3.0%) 0.0002 ( 2.6%) VerifierPass + 0.0004 ( 2.2%) 0.0004 ( 5.4%) VerifierPass + 0.0013 ( 6.5%) 0.0013 ( 16.3%) LLVMLoweringPass + 0.0006 ( 2.8%) 0.0006 ( 7.0%) VerifierPass + 0.0200 (100.0%) 0.0081 (100.0%) Total +``` + +#### IR Printing + +When debugging it is often useful to dump the IR at various stages of a pass +pipeline. This is where the IR printing instrumentation comes into play. This +instrumentation allows for conditionally printing the IR before and after pass +execution by optionally filtering on the pass being executed. This +instrumentation can be added directly to the PassManager via the +`enableIRPrinting` method. `mlir-opt` provides a few useful flags for utilizing +this instrumentation: + +* `print-ir-before=(comma-separated-pass-list)` + * Print the IR before each of the passes provided within the pass list. +* `print-ir-before-all` + * Print the IR before every pass in the pipeline. 
+ +```shell +$ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-before=cse + +*** IR Dump Before CSE *** +func @simple_constant() -> (i32, i32) { + %c1_i32 = constant 1 : i32 + %c1_i32_0 = constant 1 : i32 + return %c1_i32, %c1_i32_0 : i32, i32 +} +``` + +* `print-ir-after=(comma-separated-pass-list)` + * Print the IR after each of the passes provided within the pass list. +* `print-ir-after-all` + * Print the IR after every pass in the pipeline. + +```shell +$ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-after=cse + +*** IR Dump After CSE *** +func @simple_constant() -> (i32, i32) { + %c1_i32 = constant 1 : i32 + return %c1_i32, %c1_i32 : i32, i32 +} +``` + +* `print-ir-after-change` + * Only print the IR after a pass if the pass mutated the IR. This helps to + reduce the number of IR dumps for "uninteresting" passes. + * Note: Changes are detected by comparing a hash of the operation before + and after the pass. This adds additional run-time to compute the hash of + the IR, and in some rare cases may result in false-positives depending + on the collision rate of the hash algorithm used. + * Note: This option should be used in unison with one of the other + 'print-ir-after' options above, as this option alone does not enable + printing. + +```shell +$ mlir-opt foo.mlir -pass-pipeline='func(cse,cse)' -print-ir-after=cse -print-ir-after-change + +*** IR Dump After CSE *** +func @simple_constant() -> (i32, i32) { + %c1_i32 = constant 1 : i32 + return %c1_i32, %c1_i32 : i32, i32 +} +``` + +* `print-ir-module-scope` + * Always print the top-level module operation, regardless of pass type or + operation nesting level. 
+ * Note: Printing at module scope should only be used when multi-threading
+ is disabled (`-disable-pass-threading`)
+
+```shell
+$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope
+
+*** IR Dump After CSE *** ('func' operation: @bar)
+func @bar(%arg0: f32, %arg1: f32) -> f32 {
+ ...
+}
+
+func @simple_constant() -> (i32, i32) {
+ %c1_i32 = constant 1 : i32
+ %c1_i32_0 = constant 1 : i32
+ return %c1_i32, %c1_i32_0 : i32, i32
+}
+
+*** IR Dump After CSE *** ('func' operation: @simple_constant)
+func @bar(%arg0: f32, %arg1: f32) -> f32 {
+ ...
+}
+
+func @simple_constant() -> (i32, i32) {
+ %c1_i32 = constant 1 : i32
+ return %c1_i32, %c1_i32 : i32, i32
+}
+```
+
+## Crash and Failure Reproduction
+
+The [pass manager](#pass-manager) in MLIR contains a builtin mechanism to
+generate reproducibles in the event of a crash, or a
+[pass failure](#pass-failure). This functionality can be enabled via
+`PassManager::enableCrashReproducerGeneration` or via the command line flag
+`pass-pipeline-crash-reproducer`. In either case, an argument is provided that
+corresponds to the output `.mlir` file name that the reproducible should be
+written to. The reproducible contains the configuration of the pass manager that
+was executing, as well as the initial IR before any passes were run. A potential
+reproducible may have the form:
+
+```mlir
+// configuration: -pass-pipeline='func(cse, canonicalize), inline'
+// note: verifyPasses=false
+
+module {
+ func @foo() {
+ ... 
+ } +} +``` diff --git a/mlir/docs/includes/img/index-map.svg b/mlir/docs/includes/img/index-map.svg new file mode 100644 index 0000000000000000000000000000000000000000..6004c2da362d1ec39b28cce73db3e937edc89a18 --- /dev/null +++ b/mlir/docs/includes/img/index-map.svg @@ -0,0 +1,380 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mlir/docs/includes/img/view-operation.svg b/mlir/docs/includes/img/view-operation.svg new file mode 100644 index 0000000000000000000000000000000000000000..f4d622ee263ce6db50358d42a34ab0177ea133e7 --- /dev/null +++ b/mlir/docs/includes/img/view-operation.svg @@ -0,0 +1,580 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mlir/examples/CMakeLists.txt b/mlir/examples/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..37c89d0bae965cfc8665515de7e60ad7867a7d8b --- /dev/null +++ b/mlir/examples/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..56002b1ad2e27aee3ca26a909e25b238599ae2d6 --- /dev/null +++ b/mlir/examples/toy/CMakeLists.txt @@ -0,0 +1,15 @@ +add_custom_target(Toy) +set_target_properties(Toy PROPERTIES FOLDER Examples) + +macro(add_toy_chapter name) + add_dependencies(Toy ${name}) + add_llvm_example(${name} ${ARGN}) +endmacro(add_toy_chapter name) + +add_subdirectory(Ch1) +add_subdirectory(Ch2) +add_subdirectory(Ch3) +add_subdirectory(Ch4) +add_subdirectory(Ch5) +add_subdirectory(Ch6) +add_subdirectory(Ch7) diff --git a/mlir/examples/toy/Ch1/CMakeLists.txt b/mlir/examples/toy/Ch1/CMakeLists.txt new 
file mode 100644 index 0000000000000000000000000000000000000000..f4e85556130161b5eaf59f6b353b608eff9f7eb9 --- /dev/null +++ b/mlir/examples/toy/Ch1/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LLVM_LINK_COMPONENTS + Support + ) + +add_toy_chapter(toyc-ch1 + toyc.cpp + parser/AST.cpp + ) +include_directories(include/) +target_link_libraries(toyc-ch1 + PRIVATE + MLIRSupport) diff --git a/mlir/examples/toy/Ch1/include/toy/AST.h b/mlir/examples/toy/Ch1/include/toy/AST.h new file mode 100644 index 0000000000000000000000000000000000000000..820600b5b1c900cbeedce7545bad458f096cc92e --- /dev/null +++ b/mlir/examples/toy/Ch1/include/toy/AST.h @@ -0,0 +1,242 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. 
+class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. 
+class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. 
+class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch1/include/toy/Lexer.h b/mlir/examples/toy/Ch1/include/toy/Lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..a77a91bb5645104b5474c680aca5368e18f130a0 --- /dev/null +++ b/mlir/examples/toy/Ch1/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. 
+enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purposes. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purposes (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. 
+ int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. 
+ do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. 
+ llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch1/include/toy/Parser.h b/mlir/examples/toy/Ch1/include/toy/Parser.h new file mode 100644 index 0000000000000000000000000000000000000000..4557ea26859de3d0a6b71448f4bef030167c3e71 --- /dev/null +++ b/mlir/examples/toy/Ch1/include/toy/Parser.h @@ -0,0 +1,485 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. 
+ Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. 
+ } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. 
+ auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current 
binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// 
initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. 
+ if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError("}", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. 
+ /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d6d9359529bffc068520bebf4a9ea56f436a415 --- /dev/null +++ b/mlir/examples/toy/Ch1/parser/AST.cpp @@ -0,0 +1,234 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "mlir/ADT/TypeSwitch.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch 
to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. 
+void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + mlir::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. 
+void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch1/toyc.cpp b/mlir/examples/toy/Ch1/toyc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48863fa931cd09d7216e262d55006ae341233775 --- /dev/null +++ b/mlir/examples/toy/Ch1/toyc.cpp @@ -0,0 +1,66 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Parser.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); +namespace { +enum Action { None, DumpAST }; +} + +static cl::opt + emitAction("emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump"))); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int main(int argc, char **argv) { + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + switch (emitAction) { + case Action::DumpAST: + dump(*moduleAST); + return 0; + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/mlir/examples/toy/Ch2/CMakeLists.txt b/mlir/examples/toy/Ch2/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7154902017eea2e262b79fc76171c0d6e1f597bd --- /dev/null +++ b/mlir/examples/toy/Ch2/CMakeLists.txt @@ -0,0 +1,21 @@ +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +add_toy_chapter(toyc-ch2 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + ) 
+include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +add_dependencies(toyc-ch2 ToyCh2OpsIncGen) +target_link_libraries(toyc-ch2 + PRIVATE + MLIRAnalysis + MLIRIR + MLIRParser + MLIRTransforms) diff --git a/mlir/examples/toy/Ch2/include/CMakeLists.txt b/mlir/examples/toy/Ch2/include/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..37c89d0bae965cfc8665515de7e60ad7867a7d8b --- /dev/null +++ b/mlir/examples/toy/Ch2/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/mlir/examples/toy/Ch2/include/toy/AST.h b/mlir/examples/toy/Ch2/include/toy/AST.h new file mode 100644 index 0000000000000000000000000000000000000000..820600b5b1c900cbeedce7545bad458f096cc92e --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/AST.h @@ -0,0 +1,242 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. 
+class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. 
+class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. 
+class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch2/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch2/include/toy/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c08f78b0e8c8a93390fb46c401499687fbc232a0 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_public_tablegen_target(ToyCh2OpsIncGen) diff --git a/mlir/examples/toy/Ch2/include/toy/Dialect.h b/mlir/examples/toy/Ch2/include/toy/Dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..385d6ddb95ac4f50f1d8e64c2a1306114affd258 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/Dialect.h @@ -0,0 +1,44 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. 
+// See g3doc/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" + +namespace mlir { +namespace toy { + +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. +class ToyDialect : public mlir::Dialect { +public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } +}; + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch2/include/toy/Lexer.h b/mlir/examples/toy/Ch2/include/toy/Lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..6eff64ee5f09634041f76cbae11c18f8ca46d07c --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. 
+ void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. 
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. 
+ Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch2/include/toy/MLIRGen.h b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c8ca1201d1a2a391c0aec0d89197fbbb18efb8 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h @@ -0,0 +1,32 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td new file mode 100644 index 0000000000000000000000000000000000000000..aa7e94fcae77db5d9c3f18efbc02745ced5e4aa1 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -0,0 +1,220 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/IR/OpBase.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. 
+class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'NoSideEffect' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &state, " + "DenseElementsAttr value", [{ + build(builder, state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<"Builder *builder, OperationState &state, double value"> + ]; + + // Invoke a static verify method to verify this constant operation. 
+ let verifier = [{ return ::verify(*this); }]; +} + +def AddOp : Toy_Op<"add"> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def GenericCallOp : Toy_Op<"generic_call"> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "StringRef callee, ArrayRef arguments"> + ]; +} + +def MulOp : Toy_Op<"mul"> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. 
+ }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); +} + +def ReshapeOp : Toy_Op<"reshape"> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // Allow building a ReturnOp with no return operand. + let builders = [OpBuilder< + "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. 
+ let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose"> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value input"> + ]; + + // Invoke a static verify method to verify this transpose operation. + let verifier = [{ return ::verify(*this); }]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch2/include/toy/Parser.h b/mlir/examples/toy/Ch2/include/toy/Parser.h new file mode 100644 index 0000000000000000000000000000000000000000..4557ea26859de3d0a6b71448f4bef030167c3e71 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/Parser.h @@ -0,0 +1,485 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. 
It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. 
+ std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. 
+ for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. 
+ lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. 
+ auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. 
+  /// decl ::= var identifier [ type ] = expr
+  std::unique_ptr parseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError("identifier",
+                        "after 'var' declaration");
+    std::string id = lexer.getId();
+    lexer.getNextToken(); // eat id
+
+    std::unique_ptr type; // Type is optional, it can be inferred
+    if (lexer.getCurToken() == '<') {
+      type = parseType();
+      if (!type)
+        return nullptr;
+    }
+
+    if (!type)
+      type = std::make_unique();
+    lexer.consume(Token('='));
+    auto expr = parseExpression();
+    return std::make_unique(std::move(loc), std::move(id),
+                            std::move(*type), std::move(expr));
+  }
+
+  /// Parse a block: a list of expression separated by semicolons and wrapped in
+  /// curly braces.
+  ///
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  std::unique_ptr parseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    auto exprList = std::make_unique();
+
+    // Ignore empty expressions: swallow sequences of semicolons.
+    while (lexer.getCurToken() == ';')
+      lexer.consume(Token(';'));
+
+    while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+      if (lexer.getCurToken() == tok_var) {
+        // Variable declaration
+        auto varDecl = parseDeclaration();
+        if (!varDecl)
+          return nullptr;
+        exprList->push_back(std::move(varDecl));
+      } else if (lexer.getCurToken() == tok_return) {
+        // Return statement
+        auto ret = parseReturn();
+        if (!ret)
+          return nullptr;
+        exprList->push_back(std::move(ret));
+      } else {
+        // General expression
+        auto expr = parseExpression();
+        if (!expr)
+          return nullptr;
+        exprList->push_back(std::move(expr));
+      }
+      // Ensure that elements are separated by a semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError(";", "after expression");
+
+      // Ignore empty expressions: swallow sequences of semicolons.
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError("}", "to close block");
+
+    lexer.consume(Token('}'));
+    return exprList;
+  }
+
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  std::unique_ptr parsePrototype() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_def);
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError("function name", "in prototype");
+
+    std::string fnName = lexer.getId();
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector> args;
+    if (lexer.getCurToken() != ')') {
+      do {
+        std::string name = lexer.getId();
+        auto loc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        auto decl = std::make_unique(std::move(loc), name);
+        args.push_back(std::move(decl));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError(")", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return std::make_unique(std::move(loc), fnName,
+                            std::move(args));
+  }
+
+  /// Parse a function definition, we expect a prototype initiated with the
+  /// `def` keyword, followed by a block containing a list of expressions.
+ /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b4d669d18eabd72acb11eea6b4dbc1c3dab4ecd --- /dev/null +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -0,0 +1,180 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. 
+static mlir::LogicalResult verify(ConstantOp op) { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = + op.getResult()->getType().dyn_cast(); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = op.value().getType().cast(); + if (attrType.getRank() != resultType.getRank()) { + return op.emitOpError( + "return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op.emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp + +void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +//===----------------------------------------------------------------------===// +// GenericCallOp + +void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. 
+ state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +//===----------------------------------------------------------------------===// +// MulOp + +void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +//===----------------------------------------------------------------------===// +// ReturnOp + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. 
+ if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(value); +} + +static mlir::LogicalResult verify(TransposeOp op) { + auto inputType = op.getOperand()->getType().dyn_cast(); + auto resultType = op.getType().dyn_cast(); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return op.emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9c960c79f47254d31e72037431a0f9d3a614276 --- /dev/null +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -0,0 +1,452 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. 
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule.push_back(func); + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(Location loc) { + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. 
+ // Arguments type are uniformly unranked tensors. + llvm::SmallVector arg_types(proto.getArgs().size(), + getType(VarType{})); + auto func_type = builder.getFunctionType(arg_types, llvm::None); + return mlir::FuncOp::create(location, proto.getName(), func_type); + } + + /// Emit a new function and add it to the MLIR module. + mlir::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + mlir::FuncOp function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + auto &entryBlock = *function.addEntryBlock(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(name_value)->getName(), + std::get<1>(name_value)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. 
+ // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + getType(VarType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. 
The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. 
+ /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. 
Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. 
+ mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. 
+ if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d6d9359529bffc068520bebf4a9ea56f436a415 --- /dev/null +++ b/mlir/examples/toy/Ch2/parser/AST.cpp @@ -0,0 +1,234 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "mlir/ADT/TypeSwitch.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // 
No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). 
+void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + mlir::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. 
+void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3e3db97b4aee74bc60fdb9b022eb8a183954ec44 --- /dev/null +++ b/mlir/examples/toy/Ch2/toyc.cpp @@ -0,0 +1,137 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR }; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int dumpMLIR() { + // Register our Dialect with MLIR. + mlir::registerDialect(); + + mlir::MLIRContext context; + + // Handle '.toy' input to the compiler. 
+ if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + mlir::OwningModuleRef module = mlirGen(context, *moduleAST); + if (!module) + return 1; + + module->dump(); + return 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + mlir::OwningModuleRef module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/mlir/examples/toy/Ch3/CMakeLists.txt b/mlir/examples/toy/Ch3/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..823edfd343a3e06b05f7cd8f9bed930b54448dbb --- /dev/null +++ b/mlir/examples/toy/Ch3/CMakeLists.txt @@ -0,0 +1,31 @@ +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include") 
+add_public_tablegen_target(ToyCh3CombineIncGen) + +add_toy_chapter(toyc-ch3 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/ToyCombine.cpp + ) + +add_dependencies(toyc-ch3 ToyCh3OpsIncGen) +add_dependencies(toyc-ch3 ToyCh3CombineIncGen) +include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +target_link_libraries(toyc-ch3 + PRIVATE + MLIRAnalysis + MLIRIR + MLIRParser + MLIRPass + MLIRTransforms) + diff --git a/mlir/examples/toy/Ch3/include/CMakeLists.txt b/mlir/examples/toy/Ch3/include/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..37c89d0bae965cfc8665515de7e60ad7867a7d8b --- /dev/null +++ b/mlir/examples/toy/Ch3/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/mlir/examples/toy/Ch3/include/toy/AST.h b/mlir/examples/toy/Ch3/include/toy/AST.h new file mode 100644 index 0000000000000000000000000000000000000000..820600b5b1c900cbeedce7545bad458f096cc92e --- /dev/null +++ b/mlir/examples/toy/Ch3/include/toy/AST.h @@ -0,0 +1,242 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". 
+class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. 
+class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch3/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch3/include/toy/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e76780c1f79f3901338b5f2b8e57c13265f7f387 --- /dev/null +++ b/mlir/examples/toy/Ch3/include/toy/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_public_tablegen_target(ToyCh3OpsIncGen) diff --git a/mlir/examples/toy/Ch3/include/toy/Dialect.h b/mlir/examples/toy/Ch3/include/toy/Dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..385d6ddb95ac4f50f1d8e64c2a1306114affd258 --- /dev/null +++ b/mlir/examples/toy/Ch3/include/toy/Dialect.h @@ -0,0 +1,44 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. 
+// See g3doc/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" + +namespace mlir { +namespace toy { + +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. +class ToyDialect : public mlir::Dialect { +public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } +}; + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch3/include/toy/Lexer.h b/mlir/examples/toy/Ch3/include/toy/Lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..6eff64ee5f09634041f76cbae11c18f8ca46d07c --- /dev/null +++ b/mlir/examples/toy/Ch3/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. 
+ void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. 
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. 
+  Token lastChar = Token(' ');
+
+  /// Keep track of the current line number in the input stream
+  int curLineNum = 0;
+
+  /// Keep track of the current column number in the input stream
+  int curCol = 0;
+
+  /// Buffer supplied by the derived class on calls to `readNextLine()`
+  llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory.
+class LexerBuffer final : public Lexer {
+public:
+  LexerBuffer(const char *begin, const char *end, std::string filename)
+      : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+  /// Provide one line at a time to the Lexer, return an empty string when
+  /// reaching the end of the buffer.
+  llvm::StringRef readNextLine() override {
+    auto *begin = current;
+    while (current <= end && *current && *current != '\n')
+      ++current;
+    if (current <= end && *current)
+      ++current;
+    llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+    return result;
+  }
+  const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/mlir/examples/toy/Ch3/include/toy/MLIRGen.h b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1c8ca1201d1a2a391c0aec0d89197fbbb18efb8
--- /dev/null
+++ b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h
@@ -0,0 +1,32 @@
+//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a simple interface to perform IR generation targeting MLIR
+// from a Module AST for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_
+#define MLIR_TUTORIAL_TOY_MLIRGEN_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class OwningModuleRef;
+} // namespace mlir
+
+namespace toy {
+class ModuleAST;
+
+/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
+/// or nullptr on failure.
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td
new file mode 100644
index 0000000000000000000000000000000000000000..80717119b2fe4deb84528863f9d69f7bc0502f14
--- /dev/null
+++ b/mlir/examples/toy/Ch3/include/toy/Ops.td
@@ -0,0 +1,226 @@
+//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the operations of the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_OPS
+#define TOY_OPS
+
+include "mlir/IR/OpBase.td"
+
+// Provide a definition of the 'toy' dialect in the ODS framework so that we
+// can define our operations.
+def Toy_Dialect : Dialect {
+  let name = "toy";
+  let cppNamespace = "toy";
+}
+
+// Base class for toy dialect operations. This operation inherits from the base
+// `Op` class in OpBase.td, and provides:
+//   * The parent dialect of the operation.
+//   * The mnemonic for the operation, or the name without the dialect prefix.
+//   * A list of traits for the operation.
+class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Toy_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+// We define a toy operation by inheriting from our base 'Toy_Op' class above.
+// Here we provide the mnemonic and a list of traits for the operation. The
+// constant operation is marked as 'NoSideEffect' as it is a pure operation
+// and may be removed if dead.
+def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
+  // Provide a summary and description for this operation. This can be used to
+  // auto-generate documentation of the operations within our dialect.
+  let summary = "constant";
+  let description = [{
+    Constant operation turns a literal into an SSA value. The data is attached
+    to the operation as an attribute. For example:
+
+    ```mlir
+      %0 = "toy.constant"()
+         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
+        : () -> tensor<2x3xf64>
+    ```
+  }];
+
+  // The constant operation takes an attribute as the only input.
+  let arguments = (ins F64ElementsAttr:$value);
+
+  // The constant operation returns a single value of TensorType.
+  let results = (outs F64Tensor);
+
+  // Add custom build methods for the constant operation. These method populates
+  // the `state` that MLIR uses to create operations, i.e. these are used when
+  // using `builder.create<ConstantOp>(...)`.
+  let builders = [
+    // Build a constant with a given constant tensor value.
+    OpBuilder<"Builder *builder, OperationState &state, "
+              "DenseElementsAttr value", [{
+      build(builder, state, value.getType(), value);
+    }]>,
+
+    // Build a constant with a given constant floating-point value.
+    OpBuilder<"Builder *builder, OperationState &state, double value">
+  ];
+
+  // Invoke a static verify method to verify this constant operation.
+ let verifier = [{ return ::verify(*this); }]; +} + +def AddOp : Toy_Op<"add", [NoSideEffect]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def GenericCallOp : Toy_Op<"generic_call"> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "StringRef callee, ArrayRef arguments"> + ]; +} + +def MulOp : Toy_Op<"mul", [NoSideEffect]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. 
+ }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); +} + +def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // Enabled registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // Allow building a ReturnOp with no return operand. 
+ let builders = [OpBuilder< + "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + // Enabled registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value input"> + ]; + + // Invoke a static verify method to verify this transpose operation. + let verifier = [{ return ::verify(*this); }]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch3/include/toy/Parser.h b/mlir/examples/toy/Ch3/include/toy/Parser.h new file mode 100644 index 0000000000000000000000000000000000000000..4557ea26859de3d0a6b71448f4bef030167c3e71 --- /dev/null +++ b/mlir/examples/toy/Ch3/include/toy/Parser.h @@ -0,0 +1,485 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+  /// Create a Parser for the supplied lexer.
+  Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.
+  std::unique_ptr<ModuleAST> parseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Parse functions one at a time and accumulate in this vector.
+    std::vector<FunctionAST> functions;
+    while (auto f = parseDefinition()) {
+      functions.push_back(std::move(*f));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+    // If we didn't reach EOF, there was an error during parsing
+    if (lexer.getCurToken() != tok_eof)
+      return parseError<ModuleAST>("nothing", "at end of module");
+
+    return std::make_unique<ModuleAST>(std::move(functions));
+  }
+
+private:
+  Lexer &lexer;
+
+  /// Parse a return statement.
+  /// return :== return ; | return expr ;
+  std::unique_ptr<ReturnExprAST> parseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // return takes an optional argument
+    llvm::Optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      expr = parseExpression();
+      if (!expr)
+        return nullptr;
+    }
+    return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+  /// Parse a literal number.
+ /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. 
+ if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. 
+ lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. 
+ auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. 
+ /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. 
+ if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError("}", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. 
+  ///
+  /// definition ::= prototype block
+  std::unique_ptr<FunctionAST> parseDefinition() {
+    auto proto = parsePrototype();
+    if (!proto)
+      return nullptr;
+
+    if (auto block = parseBlock())
+      return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
+    return nullptr;
+  }
+
+  /// Get the precedence of the pending binary operator token.
+  int getTokPrecedence() {
+    if (!isascii(lexer.getCurToken()))
+      return -1;
+
+    // 1 is lowest precedence.
+    switch (static_cast<char>(lexer.getCurToken())) {
+    case '-':
+      return 20;
+    case '+':
+      return 20;
+    case '*':
+      return 40;
+    default:
+      return -1;
+    }
+  }
+
+  /// Helper function to signal errors while parsing, it takes an argument
+  /// indicating the expected token and another argument giving more context.
+  /// Location is retrieved from the lexer to enrich the error message.
+  template <typename R, typename T, typename U = const char *>
+  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+    auto curToken = lexer.getCurToken();
+    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+                 << lexer.getLastLocation().col << "): expected '" << expected
+                 << "' " << context << " but has Token " << curToken;
+    if (isprint(curToken))
+      llvm::errs() << " '" << (char)curToken << "'";
+    llvm::errs() << "\n";
+    return nullptr;
+  }
+};
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PARSER_H
diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6b4d669d18eabd72acb11eea6b4dbc1c3dab4ecd
--- /dev/null
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -0,0 +1,180 @@
+//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. 
+static mlir::LogicalResult verify(ConstantOp op) {
+  // If the return type of the constant is not an unranked tensor, the shape
+  // must match the shape of the attribute holding the data.
+  auto resultType =
+      op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
+  if (!resultType)
+    return success();
+
+  // Check that the rank of the attribute type matches the rank of the constant
+  // result type.
+  auto attrType = op.value().getType().cast<mlir::TensorType>();
+  if (attrType.getRank() != resultType.getRank()) {
+    return op.emitOpError(
+               "return type must match the one of the attached value "
+               "attribute: ")
+           << attrType.getRank() << " != " << resultType.getRank();
+  }
+
+  // Check that each of the dimensions match between the two types.
+  for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
+    if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+      return op.emitOpError(
+                 "return type shape mismatches its attribute at dimension ")
+             << dim << ": " << attrType.getShape()[dim]
+             << " != " << resultType.getShape()[dim];
+    }
+  }
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// AddOp
+
+void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
+                  mlir::Value lhs, mlir::Value rhs) {
+  state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+//===----------------------------------------------------------------------===//
+// GenericCallOp
+
+void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
+                          StringRef callee, ArrayRef<mlir::Value> arguments) {
+  // Generic call always returns an unranked Tensor initially.
+ state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +//===----------------------------------------------------------------------===// +// MulOp + +void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +//===----------------------------------------------------------------------===// +// ReturnOp + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. 
+ if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(value); +} + +static mlir::LogicalResult verify(TransposeOp op) { + auto inputType = op.getOperand()->getType().dyn_cast(); + auto resultType = op.getType().dyn_cast(); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return op.emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9c960c79f47254d31e72037431a0f9d3a614276 --- /dev/null +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -0,0 +1,452 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. 
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule.push_back(func); + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(Location loc) { + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. 
+ // Arguments type are uniformly unranked tensors. + llvm::SmallVector arg_types(proto.getArgs().size(), + getType(VarType{})); + auto func_type = builder.getFunctionType(arg_types, llvm::None); + return mlir::FuncOp::create(location, proto.getName(), func_type); + } + + /// Emit a new function and add it to the MLIR module. + mlir::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + mlir::FuncOp function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + auto &entryBlock = *function.addEntryBlock(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(name_value)->getName(), + std::get<1>(name_value)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. 
+ // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + getType(VarType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. 
The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. 
+ /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. 
Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. 
+ mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. 
+ if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e320540217935bab1df63e1afb0878c3fb03a000 --- /dev/null +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -0,0 +1,69 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "toy/Dialect.h" +#include +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // end anonymous namespace + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + + // If the input is defined by another Transpose, bingo! + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. 
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.td b/mlir/examples/toy/Ch3/mlir/ToyCombine.td new file mode 100644 index 0000000000000000000000000000000000000000..e6e33e84d7e8f3e13aea9840f3690029de025d94 --- /dev/null +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.td @@ -0,0 +1,62 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_COMBINE
+#define TOY_COMBINE
+
+include "toy/Ops.td"
+
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+///    dag sourcePattern, list<dag> resultPatterns,
+///    list<dag> additionalConstraints = [],
+///    dag benefitsAdded = (addBenefit 0)
+/// >;
+
+//===----------------------------------------------------------------------===//
+// Basic Pattern-Match and Rewrite
+//===----------------------------------------------------------------------===//
+
+// Reshape(Reshape(x)) = Reshape(x)
+def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
+                                   (ReshapeOp $arg)>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite using Native Code Call
+//===----------------------------------------------------------------------===//
+
+// Native Code Calls may be used for more complex transformations using inline
+// C++ and C++ helper functions.
+
+// Reshape(Constant(x)) = x'
+def ReshapeConstant :
+  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
+def FoldConstantReshapeOptPattern : Pat<
+  (ReshapeOp:$res (ConstantOp $arg)),
+  (ConstantOp (ReshapeConstant $arg, $res))>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite with Constraints
+//===----------------------------------------------------------------------===//
+
+// DRR allows for constraint checking when the transformation is conditional
+// on operand properties.
+
+// Reshape(x) = x, where input and output shapes are identical
+def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
+def RedundantReshapeOptPattern : Pat<
+  (ReshapeOp:$res $arg), (replaceWithValue $arg),
+  [(TypesAreIdentical $res, $arg)]>;
+
+#endif // TOY_COMBINE
diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0d6d9359529bffc068520bebf4a9ea56f436a415
--- /dev/null
+++ b/mlir/examples/toy/Ch3/parser/AST.cpp
@@ -0,0 +1,234 @@
+//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AST dump for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/AST.h"
+
+#include "mlir/ADT/TypeSwitch.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+
+namespace {
+
+// RAII helper to manage increasing/decreasing the indentation as we traverse
+// the AST
+struct Indent {
+  Indent(int &level) : level(level) { ++level; }
+  ~Indent() { --level; }
+  int &level;
+};
+
+/// Helper class that implement the AST tree traversal and print the nodes along
+/// the way. The only data member is the current indentation level.
+class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) {
+  INDENT();
+  llvm::errs() << "VarDecl " << varDecl->getName();
+  dump(varDecl->getType());
+  llvm::errs() << " " << loc(varDecl) << "\n";
+  dump(varDecl->getInitVal());
+}
+
+/// A "block", or a list of expression
+void ASTDumper::dump(ExprASTList *exprList) {
+  INDENT();
+  llvm::errs() << "Block {\n";
+  for (auto &expr : *exprList)
+    dump(expr.get());
+  indent();
+  llvm::errs() << "} // Block\n";
+}
+
+/// A literal number, just print the value.
+void ASTDumper::dump(NumberExprAST *num) {
+  INDENT();
+  llvm::errs() << num->getValue() << " " << loc(num) << "\n";
+}
+
+/// Helper to print recursively a literal. This handles nested array like:
+///    [ [ 1, 2 ], [ 3, 4 ] ]
+/// We print out such array with the dimensions spelled out at every level:
+///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
+void printLitHelper(ExprAST *litOrNum) {
+  // Inside a literal expression we can have either a number or another literal
+  if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
+    llvm::errs() << num->getValue();
+    return;
+  }
+  auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
+
+  // Print the dimension for this literal first
+  llvm::errs() << "<";
+  mlir::interleaveComma(literal->getDims(), llvm::errs());
+  llvm::errs() << ">";
+
+  // Now print the content, recursing on every element of the list
+  llvm::errs() << "[ ";
+  mlir::interleaveComma(literal->getValues(), llvm::errs(),
+                        [&](auto &elt) { printLitHelper(elt.get()); });
+  llvm::errs() << "]";
+}
+
+/// Print a literal, see the recursive helper above for the implementation.
+void ASTDumper::dump(LiteralExprAST *node) {
+  INDENT();
+  llvm::errs() << "Literal: ";
+  printLitHelper(node);
+  llvm::errs() << " " << loc(node) << "\n";
+}
+
+/// Print a variable reference (just a name).
+void ASTDumper::dump(VariableExprAST *node) {
+  INDENT();
+  llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
+}
+
+/// Return statement print the return and its (optional) argument.
+void ASTDumper::dump(ReturnExprAST *node) {
+  INDENT();
+  llvm::errs() << "Return\n";
+  if (node->getExpr().hasValue())
+    return dump(*node->getExpr());
+  {
+    INDENT();
+    llvm::errs() << "(void)\n";
+  }
+}
+
+/// Print a binary operation, first the operator, then recurse into LHS and RHS.
+void ASTDumper::dump(BinaryExprAST *node) {
+  INDENT();
+  llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
+  dump(node->getLHS());
+  dump(node->getRHS());
+}
+
+/// Print a call expression, first the callee name and the list of args by
+/// recursing into each individual argument.
+void ASTDumper::dump(CallExprAST *node) {
+  INDENT();
+  llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
+  for (auto &arg : node->getArgs())
+    dump(arg.get());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print a builtin print call, first the builtin name and then the argument.
+void ASTDumper::dump(PrintExprAST *node) {
+  INDENT();
+  llvm::errs() << "Print [ " << loc(node) << "\n";
+  dump(node->getArg());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print type: only the shape is printed in between '<' and '>'
+void ASTDumper::dump(const VarType &type) {
+  llvm::errs() << "<";
+  mlir::interleaveComma(type.shape, llvm::errs());
+  llvm::errs() << ">";
+}
+
+/// Print a function prototype, first the function name, and then the list of
+/// parameters names.
+void ASTDumper::dump(PrototypeAST *node) {
+  INDENT();
+  llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
+  indent();
+  llvm::errs() << "Params: [";
+  mlir::interleaveComma(node->getArgs(), llvm::errs(),
+                        [](auto &arg) { llvm::errs() << arg->getName(); });
+  llvm::errs() << "]\n";
+}
+
+/// Print a function, first the prototype and then the body.
+void ASTDumper::dump(FunctionAST *node) {
+  INDENT();
+  llvm::errs() << "Function \n";
+  dump(node->getProto());
+  dump(node->getBody());
+}
+
+/// Print a module, actually loop over the functions and print them in sequence.
+void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8b6e94786bed91eb2a3c3dcfd963daa5efbfdb4 --- /dev/null +++ b/mlir/examples/toy/Ch3/toyc.cpp @@ -0,0 +1,157 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the 
input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR }; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int dumpMLIR() { + // Register our Dialect with MLIR. 
+ mlir::registerDialect(); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; + llvm::SourceMgr sourceMgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + if (int error = loadMLIR(sourceMgr, context, module)) + return error; + + if (enableOpt) { + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Add a run of the canonicalizer to optimize the mlir module. + pm.addNestedPass(mlir::createCanonicalizerPass()); + if (mlir::failed(pm.run(*module))) + return 4; + } + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/mlir/examples/toy/Ch4/CMakeLists.txt b/mlir/examples/toy/Ch4/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d11e5abcf13037f608f67c0a047e1f30ccf9c57e --- /dev/null +++ b/mlir/examples/toy/Ch4/CMakeLists.txt @@ -0,0 +1,35 @@ +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include") +add_public_tablegen_target(ToyCh4CombineIncGen) + +add_toy_chapter(toyc-ch4 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/DeadFunctionEliminationPass.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + ) + 
+add_dependencies(toyc-ch4 ToyCh4OpsIncGen) +add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen) +add_dependencies(toyc-ch4 ToyCh4CombineIncGen) +add_dependencies(toyc-ch4 MLIRCallOpInterfacesIncGen) +include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +target_link_libraries(toyc-ch4 + PRIVATE + MLIRAnalysis + MLIRIR + MLIRParser + MLIRPass + MLIRTransforms) + diff --git a/mlir/examples/toy/Ch4/include/CMakeLists.txt b/mlir/examples/toy/Ch4/include/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..37c89d0bae965cfc8665515de7e60ad7867a7d8b --- /dev/null +++ b/mlir/examples/toy/Ch4/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/mlir/examples/toy/Ch4/include/toy/AST.h b/mlir/examples/toy/Ch4/include/toy/AST.h new file mode 100644 index 0000000000000000000000000000000000000000..820600b5b1c900cbeedce7545bad458f096cc92e --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/AST.h @@ -0,0 +1,242 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". 
+class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. 
+class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..798d0df1d8d685f0ffd97d70eac806794cfd2503 --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +add_public_tablegen_target(ToyCh4OpsIncGen) + +set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ToyCh4ShapeInferenceInterfaceIncGen) diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..5e8b91dcf4843762db80cde22ef96a0b22929840 --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h @@ -0,0 +1,46 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with 
LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See g3doc/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" +#include "toy/ShapeInferenceInterface.h" + +namespace mlir { +namespace toy { + +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. +class ToyDialect : public mlir::Dialect { +public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } +}; + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch4/include/toy/Lexer.h b/mlir/examples/toy/Ch4/include/toy/Lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..6eff64ee5f09634041f76cbae11c18f8ca46d07c --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. 
+ Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. 
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. 
+ Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch4/include/toy/MLIRGen.h b/mlir/examples/toy/Ch4/include/toy/MLIRGen.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c8ca1201d1a2a391c0aec0d89197fbbb18efb8 --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/MLIRGen.h @@ -0,0 +1,32 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td new file mode 100644 index 0000000000000000000000000000000000000000..dfb11cf23b9aa7dc514f4e8610e04f138b8ba35f --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -0,0 +1,246 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Analysis/CallInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. 
+class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'NoSideEffect' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &state, " + "DenseElementsAttr value", [{ + build(builder, state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<"Builder *builder, OperationState &state, double value"> + ]; + + // Invoke a static verify method to verify this constant operation. 
+ let verifier = [{ return ::verify(*this); }]; +} + +def AddOp : Toy_Op<"add", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def CastOp : Toy_Op<"cast", + [DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + // Set the folder bit so that we can fold redundant cast operations. + let hasFolder = 1; +} + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. 
+ }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "StringRef callee, ArrayRef arguments"> + ]; +} + +def MulOp : Toy_Op<"mul", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); +} + +def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. 
+ let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // Allow building a ReturnOp with no return operand. + let builders = [OpBuilder< + "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value input"> + ]; + + // Invoke a static verify method to verify this transpose operation. 
+ let verifier = [{ return ::verify(*this); }]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch4/include/toy/Parser.h b/mlir/examples/toy/Ch4/include/toy/Parser.h new file mode 100644 index 0000000000000000000000000000000000000000..4557ea26859de3d0a6b71448f4bef030167c3e71 --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/Parser.h @@ -0,0 +1,485 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. 
+ std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. 
+ if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. 
+ return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. 
+ int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. 
+ /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. 
+ if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError("}", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. 
+ /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch4/include/toy/Passes.h b/mlir/examples/toy/Ch4/include/toy/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..93c51309008fca3771099b863854f0fe9e5655e5 --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/Passes.h @@ -0,0 +1,27 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PASSES_H +#define MLIR_TUTORIAL_TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createShapeInferencePass(); +std::unique_ptr createDeadFunctionEliminationPass(); +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..da0fb66018ee4df1882d26f074ecd49a24ddcea9 --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. 
+#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000000000000000000000000000000000000..1b38ada1622862057ad2c18eabe147b875e18cf2 --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. 
+ }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ee34547860cd98c27c21da874ad794a6d0c99d5 --- /dev/null +++ b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp @@ -0,0 +1,59 @@ +//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Module level pass performing dead function +// elimination. This is required as a post-processing step after function +// inlining. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Passes.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace { +/// This is a simple function DCE pass that deletes all non-main functions after +/// inlining. +/// TODO(riverriddle) This is only necessary because MLIR currently does not +/// have generic DCE support for functions. 
+class DeadFunctionEliminationPass + : public mlir::ModulePass { +public: + void runOnModule() override { + mlir::ModuleOp module = getModule(); + mlir::SymbolTable moduleSymTable(module); + + // Eliminate non-main functions. + auto mainFn = moduleSymTable.lookup("main"); + for (mlir::FuncOp func : + llvm::make_early_inc_range(module.getOps())) { + if (func != mainFn) + func.erase(); + } + } +}; +} // end anonymous namespace + +/// Create a pass that eliminates inlined functions in toy. +std::unique_ptr mlir::toy::createDeadFunctionEliminationPass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a9ded0c3d38ae810d6dd114f4c3a0d85df65b60 --- /dev/null +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -0,0 +1,261 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. 
+struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. 
This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. +static mlir::LogicalResult verify(ConstantOp op) { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = + op.getResult()->getType().dyn_cast(); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = op.value().getType().cast(); + if (attrType.getRank() != resultType.getRank()) { + return op.emitOpError( + "return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. 
+ for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op.emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp + +void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// CastOp + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } + +//===----------------------------------------------------------------------===// +// GenericCallOp + +void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return getAttrOfType("callee"); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. 
+Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } + +//===----------------------------------------------------------------------===// +// MulOp + +void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. 
+ if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = getOperand()->getType().cast(); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +static mlir::LogicalResult verify(TransposeOp op) { + auto inputType = op.getOperand()->getType().dyn_cast(); + auto resultType = op.getType().dyn_cast(); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return op.emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9c960c79f47254d31e72037431a0f9d3a614276 --- /dev/null +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -0,0 +1,452 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. 
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule.push_back(func); + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(Location loc) { + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. 
+ // Arguments type are uniformly unranked tensors. + llvm::SmallVector arg_types(proto.getArgs().size(), + getType(VarType{})); + auto func_type = builder.getFunctionType(arg_types, llvm::None); + return mlir::FuncOp::create(location, proto.getName(), func_type); + } + + /// Emit a new function and add it to the MLIR module. + mlir::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + mlir::FuncOp function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + auto &entryBlock = *function.addEntryBlock(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(name_value)->getName(), + std::get<1>(name_value)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. 
+ // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + getType(VarType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. 
The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. 
+ /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. 
Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. 
+ mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. 
+ if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..517a1f075306485003e099ed805a23f77cb49147 --- /dev/null +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -0,0 +1,104 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a FunctionPass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +class ShapeInferencePass : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto f = getFunction(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. 
+ llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa(); + }); + } +}; +} // end anonymous namespace + +/// Create a Shape Inference pass. 
+std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82c247c1be2d4da5ac4419f9267381379f8b365c --- /dev/null +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -0,0 +1,74 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "toy/Dialect.h" +#include +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // end anonymous namespace + +/// Fold simple cast operations that return the same type as the input. +OpFoldResult CastOp::fold(ArrayRef operands) { + return mlir::impl::foldCastOp(*this); +} + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. 
The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + + // If the input is defined by another Transpose, bingo! + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.td b/mlir/examples/toy/Ch4/mlir/ToyCombine.td new file mode 100644 index 0000000000000000000000000000000000000000..e6e33e84d7e8f3e13aea9840f3690029de025d94 --- /dev/null +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.td @@ -0,0 +1,62 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(($1->getType()).cast())">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. 
+ +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : ConstraintgetType() == $1->getType()">>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/mlir/examples/toy/Ch4/parser/AST.cpp b/mlir/examples/toy/Ch4/parser/AST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d6d9359529bffc068520bebf4a9ea56f436a415 --- /dev/null +++ b/mlir/examples/toy/Ch4/parser/AST.cpp @@ -0,0 +1,234 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "mlir/ADT/TypeSwitch.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. 
+class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. 
+void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + mlir::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. 
+void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7b584407f65627837129ef66ad864fe04115029 --- /dev/null +++ b/mlir/examples/toy/Ch4/toyc.cpp @@ -0,0 +1,167 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + 
cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR }; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int dumpMLIR() { + // Register our Dialect with MLIR. 
+ mlir::registerDialect(); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; + llvm::SourceMgr sourceMgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + if (int error = loadMLIR(sourceMgr, context, module)) + return error; + + if (enableOpt) { + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + if (mlir::failed(pm.run(*module))) + return 4; + } + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..df5239589de24040eb54a2af475ebbe35e16c0ee --- /dev/null +++ b/mlir/examples/toy/Ch5/CMakeLists.txt @@ -0,0 +1,42 @@ +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) 
+# Generate the rewrite-pattern implementations from the TableGen description.
+mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
+add_public_tablegen_target(ToyCh5CombineIncGen)
+
+# The chapter 5 compiler executable and its sources.
+add_toy_chapter(toyc-ch5
+  toyc.cpp
+  parser/AST.cpp
+  mlir/MLIRGen.cpp
+  mlir/Dialect.cpp
+  mlir/DeadFunctionEliminationPass.cpp
+  mlir/LowerToAffineLoops.cpp
+  mlir/ShapeInferencePass.cpp
+  mlir/ToyCombine.cpp
+  )
+
+# Make sure all TableGen-generated headers exist before compiling the sources.
+add_dependencies(toyc-ch5 ToyCh5ShapeInferenceInterfaceIncGen)
+add_dependencies(toyc-ch5 ToyCh5OpsIncGen)
+add_dependencies(toyc-ch5 ToyCh5CombineIncGen)
+add_dependencies(toyc-ch5 MLIRCallOpInterfacesIncGen)
+include_directories(include/)
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
+target_link_libraries(toyc-ch5
+  PRIVATE
+    MLIRAffineOps
+    MLIRAnalysis
+    MLIRIR
+    MLIRParser
+    MLIRPass
+    MLIRStandardOps
+    MLIRTransforms)
+
+# Force-link the dialect libraries so their static registrations run.
+whole_archive_link(toyc-ch5
+  MLIRAffineOps
+  MLIRStandardOps
+  )
diff --git a/mlir/examples/toy/Ch5/include/CMakeLists.txt b/mlir/examples/toy/Ch5/include/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..37c89d0bae965cfc8665515de7e60ad7867a7d8b
--- /dev/null
+++ b/mlir/examples/toy/Ch5/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(toy)
diff --git a/mlir/examples/toy/Ch5/include/toy/AST.h b/mlir/examples/toy/Ch5/include/toy/AST.h
new file mode 100644
index 0000000000000000000000000000000000000000..820600b5b1c900cbeedce7545bad458f096cc92e
--- /dev/null
+++ b/mlir/examples/toy/Ch5/include/toy/AST.h
@@ -0,0 +1,242 @@
+//===- AST.h - Node definition for the Toy AST ----------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AST for the Toy language.
It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. 
+class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. 
+class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. 
+class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch5/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch5/include/toy/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..aaa932896d0f17e2a78f5336d3eda2bd11d285a7 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +add_public_tablegen_target(ToyCh5OpsIncGen) + +set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ToyCh5ShapeInferenceInterfaceIncGen) diff --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..5e8b91dcf4843762db80cde22ef96a0b22929840 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/Dialect.h @@ -0,0 +1,46 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with 
LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the IR Dialect for the Toy language.
+// See g3doc/Tutorials/Toy/Ch-2.md for more information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/StandardTypes.h"
+#include "toy/ShapeInferenceInterface.h"
+
+namespace mlir {
+namespace toy {
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and registers custom attributes, operations, and types (in its
+/// constructor). It can also override some general behavior exposed via virtual
+/// methods.
+class ToyDialect : public mlir::Dialect {
+public:
+  /// Constructor is declared here and defined out-of-line; presumably it
+  /// registers the generated Toy operations with `ctx` -- confirm in
+  /// Dialect.cpp.
+  explicit ToyDialect(mlir::MLIRContext *ctx);
+
+  /// Provide a utility accessor to the dialect namespace. This is used by
+  /// several utilities for casting between dialects.
+  static llvm::StringRef getDialectNamespace() { return "toy"; }
+};
+
+/// Include the auto-generated header file containing the declarations of the
+/// toy operations.
+#define GET_OP_CLASSES
+#include "toy/Ops.h.inc"
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch5/include/toy/Lexer.h b/mlir/examples/toy/Ch5/include/toy/Lexer.h
new file mode 100644
index 0000000000000000000000000000000000000000..6eff64ee5f09634041f76cbae11c18f8ca46d07c
--- /dev/null
+++ b/mlir/examples/toy/Ch5/include/toy/Lexer.h
@@ -0,0 +1,232 @@
+//===- Lexer.h - Lexer for the Toy language -------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. 
+ Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. 
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. 
+ Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch5/include/toy/MLIRGen.h b/mlir/examples/toy/Ch5/include/toy/MLIRGen.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c8ca1201d1a2a391c0aec0d89197fbbb18efb8 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/MLIRGen.h @@ -0,0 +1,32 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td new file mode 100644 index 0000000000000000000000000000000000000000..410c5df246128bd8ddba8bc264a0ab9df9f65941 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -0,0 +1,247 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Analysis/CallInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. 
+class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Toy_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+// We define a toy operation by inheriting from our base 'Toy_Op' class above.
+// Here we provide the mnemonic and a list of traits for the operation. The
+// constant operation is marked as 'NoSideEffect' as it is a pure operation
+// and may be removed if dead.
+def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
+  // Provide a summary and description for this operation. This can be used to
+  // auto-generate documentation of the operations within our dialect.
+  let summary = "constant";
+  let description = [{
+    Constant operation turns a literal into an SSA value. The data is attached
+    to the operation as an attribute. For example:
+
+    ```mlir
+      %0 = "toy.constant"()
+         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
+        : () -> tensor<2x3xf64>
+    ```
+  }];
+
+  // The constant operation takes an attribute as the only input.
+  let arguments = (ins F64ElementsAttr:$value);
+
+  // The constant operation returns a single value of TensorType.
+  let results = (outs F64Tensor);
+
+  // Add custom build methods for the constant operation. These method populates
+  // the `state` that MLIR uses to create operations, i.e. these are used when
+  // using `builder.create<ConstantOp>(...)`.
+  let builders = [
+    // Build a constant with a given constant tensor value.
+    OpBuilder<"Builder *builder, OperationState &state, "
+              "DenseElementsAttr value", [{
+      build(builder, state, value.getType(), value);
+    }]>,
+
+    // Build a constant with a given constant floating-point value.
+    OpBuilder<"Builder *builder, OperationState &state, double value">
+  ];
+
+  // Invoke a static verify method to verify this constant operation.
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def AddOp : Toy_Op<"add",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "element-wise addition operation";
+  let description = [{
+    The "add" operation performs element-wise addition between two tensors.
+    The shapes of the tensor operands are expected to match.
+  }];
+
+  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+  let results = (outs F64Tensor);
+
+  // Allow building an AddOp with from the two input operands.
+  let builders = [
+    OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs">
+  ];
+}
+
+def CastOp : Toy_Op<"cast",
+    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
+     SameOperandsAndResultShape]> {
+  let summary = "shape cast operation";
+  let description = [{
+    The "cast" operation converts a tensor from one type to an equivalent type
+    without changing any data elements. The source and destination types
+    must both be tensor types with the same element type. If both are ranked
+    then the rank should be the same and static dimensions should match. The
+    operation is invalid if converting to a mismatching constant dimension.
+  }];
+
+  let arguments = (ins F64Tensor:$input);
+  let results = (outs F64Tensor:$output);
+
+  // Set the folder bit so that we can fold redundant cast operations.
+  let hasFolder = 1;
+}
+
+def GenericCallOp : Toy_Op<"generic_call",
+    [DeclareOpInterfaceMethods<CallOpInterface>]> {
+  let summary = "generic call operation";
+  let description = [{
+    Generic calls represent calls to a user defined function that needs to
+    be specialized for the shape of its arguments. The callee name is attached
+    as a symbol reference via an attribute. The arguments list must match the
+    arguments expected by the callee. For example:
+
+    ```mlir
+     %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
+           : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+    ```
+
+    This is only valid if a function named "my_func" exists and takes two
+    arguments.
+  }];
+
+  // The generic call operation takes a symbol reference attribute as the
+  // callee, and inputs for the call.
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+
+  // The generic call operation returns a single value of TensorType.
+  let results = (outs F64Tensor);
+
+  // Add custom build methods for the generic call operation.
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &state, "
+              "StringRef callee, ArrayRef<Value> arguments">
+  ];
+}
+
+def MulOp : Toy_Op<"mul",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "element-wise multiplication operation";
+  let description = [{
+    The "mul" operation performs element-wise multiplication between two
+    tensors. The shapes of the tensor operands are expected to match.
+  }];
+
+  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+  let results = (outs F64Tensor);
+
+  // Allow building a MulOp with from the two input operands.
+  let builders = [
+    OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs">
+  ];
+}
+
+def PrintOp : Toy_Op<"print"> {
+  let summary = "print operation";
+  let description = [{
+    The "print" builtin operation prints a given input tensor, and produces
+    no results.
+  }];
+
+  // The print operation takes an input tensor to print.
+  // We also allow a F64MemRef to enable interop during partial lowering.
+  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
+}
+
+def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
+  let summary = "tensor reshape operation";
+  let description = [{
+    Reshape operation is transforming its input tensor into a new tensor with
+    the same number of elements but different shapes. For example:
+
+    ```mlir
+       %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
+    ```
+  }];
+
+  let arguments = (ins F64Tensor:$input);
+  let hasCanonicalizer = 1;
+
+  // We expect that the reshape operation returns a statically shaped tensor.
+  let results = (outs StaticShapeTensorOf<[F64]>);
+}
+
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+  let summary = "return operation";
+  let description = [{
+    The "return" operation represents a return operation within a function.
+    The operation takes an optional tensor operand and produces no results.
+    The operand type must match the signature of the function that contains
+    the operation. For example:
+
+    ```mlir
+      func @foo() -> tensor<2xf64> {
+        ...
+        toy.return %0 : tensor<2xf64>
+      }
+    ```
+  }];
+
+  // The return operation takes an optional input operand to return. This
+  // value must match the return type of the enclosing function.
+  let arguments = (ins Variadic<F64Tensor>:$input);
+
+  // Allow building a ReturnOp with no return operand.
+  let builders = [OpBuilder<
+    "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
+  >];
+
+  // Provide extra utility definitions on the c++ operation class definition.
+  let extraClassDeclaration = [{
+    bool hasOperand() { return getNumOperands() != 0; }
+  }];
+
+  // Invoke a static verify method to verify this return operation.
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def TransposeOp : Toy_Op<"transpose",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "transpose operation";
+
+  let arguments = (ins F64Tensor:$input);
+  let results = (outs F64Tensor);
+  let hasCanonicalizer = 1;
+
+  // Allow building a TransposeOp with from the input operand.
+  let builders = [
+    OpBuilder<"Builder *b, OperationState &state, Value input">
+  ];
+
+  // Invoke a static verify method to verify this transpose operation.
+  let verifier = [{ return ::verify(*this); }];
+}
+
+#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch5/include/toy/Parser.h b/mlir/examples/toy/Ch5/include/toy/Parser.h
new file mode 100644
index 0000000000000000000000000000000000000000..4557ea26859de3d0a6b71448f4bef030167c3e71
--- /dev/null
+++ b/mlir/examples/toy/Ch5/include/toy/Parser.h
@@ -0,0 +1,485 @@
+//===- Parser.h - Toy Language Parser -------------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the parser for the Toy language. It processes the Token
+// provided by the Lexer and returns an AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+  /// Create a Parser for the supplied lexer.
+  Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.
+  std::unique_ptr<ModuleAST> parseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Parse functions one at a time and accumulate in this vector.
+ std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. 
+ if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. 
+ return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. 
+  int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/rhs. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// an identifier and an optional type (shape specification) before the + /// initializer. 
+ /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. 
+ if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError("}", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. 
+ /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch5/include/toy/Passes.h b/mlir/examples/toy/Ch5/include/toy/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..97a5d0db46c5b8fad86035a3b20a18d852ca84a5 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/Passes.h @@ -0,0 +1,32 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PASSES_H +#define MLIR_TUTORIAL_TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createDeadFunctionEliminationPass(); +std::unique_ptr createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr createLowerToAffinePass(); + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..da0fb66018ee4df1882d26f074ecd49a24ddcea9 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. 
+#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000000000000000000000000000000000000..1b38ada1622862057ad2c18eabe147b875e18cf2 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. 
+ }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ee34547860cd98c27c21da874ad794a6d0c99d5 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp @@ -0,0 +1,59 @@ +//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Module level pass performing dead function +// elimination. This is required as a post-processing step after function +// inlining. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Passes.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace { +/// This is a simple function DCE pass that deletes all non-main functions after +/// inlining. +/// TODO(riverriddle) This is only necessary because MLIR currently does not +/// have generic DCE support for functions. 
+class DeadFunctionEliminationPass + : public mlir::ModulePass { +public: + void runOnModule() override { + mlir::ModuleOp module = getModule(); + mlir::SymbolTable moduleSymTable(module); + + // Eliminate non-main functions. + auto mainFn = moduleSymTable.lookup("main"); + for (mlir::FuncOp func : + llvm::make_early_inc_range(module.getOps())) { + if (func != mainFn) + func.erase(); + } + } +}; +} // end anonymous namespace + +/// Create a pass that eliminates inlined functions in toy. +std::unique_ptr mlir::toy::createDeadFunctionEliminationPass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a9ded0c3d38ae810d6dd114f4c3a0d85df65b60 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -0,0 +1,261 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. 
+struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. 
This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. +static mlir::LogicalResult verify(ConstantOp op) { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = + op.getResult()->getType().dyn_cast(); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = op.value().getType().cast(); + if (attrType.getRank() != resultType.getRank()) { + return op.emitOpError( + "return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. 
+ for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op.emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp + +void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// CastOp + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } + +//===----------------------------------------------------------------------===// +// GenericCallOp + +void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return getAttrOfType("callee"); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. 
+Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } + +//===----------------------------------------------------------------------===// +// MulOp + +void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. 
+ if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = getOperand()->getType().cast(); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +static mlir::LogicalResult verify(TransposeOp op) { + auto inputType = op.getOperand()->getType().dyn_cast(); + auto resultType = op.getType().dyn_cast(); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return op.emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d6e76de069ce235033287496a0ed556789fcf4a --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,309 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops and standard operations. This lowering expects that all calls +// have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given TensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(TensorType type) { + assert(type.hasRank() && "expected only ranked shapes"); + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc.getOperation()->getBlock(); + alloc.getOperation()->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. 
It takes as input a rewriter, an array of memRefOperands corresponding +/// to the operands of the input operation, and the set of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; + +static void lowerOpToLoops(Operation *op, ArrayRef operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create an empty affine loop for each of the dimensions within the shape. + SmallVector loopIvs; + for (auto dim : tensorType.getShape()) { + auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); + loop.getBody()->clear(); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body and update the rewriter insertion point to the + // beginning of the loop. + rewriter.setInsertionPointToStart(loop.getBody()); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to the processing function with the rewriter, the memref + // operands, and the loop induction variables. This function will return the + // value to store at the current index. + Value valueToStore = processIteration(rewriter, operands, loopIvs); + rewriter.create(loc, valueToStore, alloc, + llvm::makeArrayRef(loopIvs)); + + // Replace this operation with the generated alloc. 
+ rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the BinaryOp. This + // allows for using the nice named accessors that are generated by the + // ODS. + typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the inner + // loop. + auto loadedLhs = + rewriter.create(loc, binaryAdaptor.lhs(), loopIvs); + auto loadedRhs = + rewriter.create(loc, binaryAdaptor.rhs(), loopIvs); + + // Create the binary operation performed on the loaded values. 
+ return rewriter.create(loc, loadedLhs, loadedRhs); + }); + return matchSuccess(); + } +}; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.value(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = op.getType().cast(); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back(rewriter.create(loc, i)); + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.getValues().begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. 
+ if (dimension == valueShape.size()) { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::makeArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return matchFailure(); + + // We lower "toy.return" directly to "std.return". 
+ rewriter.replaceOpWithNewOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the TransposeOp. + // This allows for using the nice named accessors that are generated + // by the ODS. + toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.input(); + + // Transpose the elements by generating a load from the reverse + // indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return rewriter.create(loc, input, reverseIvs); + }); + return matchSuccess(); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass : public FunctionPass { + void runOnFunction() final; +}; +} // end anonymous namespace. + +void ToyToAffineLoweringPass::runOnFunction() { + auto function = getFunction(); + + // We only lower the main function as we expect that all other functions have + // been inlined. 
+ if (function.getName() != "main") + return; + + // Verify that the given main has no inputs and results. + if (function.getNumArguments() || function.getType().getNumResults()) { + function.emitError("expected 'main' to have 0 inputs and 0 results"); + return signalPassFailure(); + } + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine` and `Standard` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. + target.addIllegalDialect(); + target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); +} + +/// Create a pass for lowering operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). 
+std::unique_ptr mlir::toy::createLowerToAffinePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9c960c79f47254d31e72037431a0f9d3a614276 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -0,0 +1,452 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. 
+class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule.push_back(func); + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(Location loc) { + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. 
+ mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + // Arguments type are uniformly unranked tensors. + llvm::SmallVector arg_types(proto.getArgs().size(), + getType(VarType{})); + auto func_type = builder.getFunctionType(arg_types, llvm::None); + return mlir::FuncOp::create(location, proto.getName(), func_type); + } + + /// Emit a new function and add it to the MLIR module. + mlir::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + mlir::FuncOp function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + auto &entryBlock = *function.addEntryBlock(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(name_value)->getName(), + std::get<1>(name_value)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. 
+ if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + getType(VarType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. 
+ switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. 
+ /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. 
Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. 
+ mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. 
+ if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..517a1f075306485003e099ed805a23f77cb49147 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -0,0 +1,104 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a FunctionPass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +class ShapeInferencePass : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto f = getFunction(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. 
+ llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa(); + }); + } +}; +} // end anonymous namespace + +/// Create a Shape Inference pass. 
+std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82c247c1be2d4da5ac4419f9267381379f8b365c --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -0,0 +1,74 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "toy/Dialect.h" +#include +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // end anonymous namespace + +/// Fold simple cast operations that return the same type as the input. +OpFoldResult CastOp::fold(ArrayRef operands) { + return mlir::impl::foldCastOp(*this); +} + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. 
The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + + // If the input is defined by another Transpose, bingo! + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.td b/mlir/examples/toy/Ch5/mlir/ToyCombine.td new file mode 100644 index 0000000000000000000000000000000000000000..e6e33e84d7e8f3e13aea9840f3690029de025d94 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.td @@ -0,0 +1,62 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines language-specific pattern match optimizations for Toy using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//

#ifndef TOY_COMBINE
#define TOY_COMBINE

include "toy/Ops.td"

/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
///    dag sourcePattern, list<dag> resultPatterns,
///    list<dag> additionalConstraints = [],
///    dag benefitsAdded = (addBenefit 0)
/// >;

//===----------------------------------------------------------------------===//
// Basic Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//

// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;

//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite using Native Code Call
//===----------------------------------------------------------------------===//

// Native Code Calls may be used for more complex transformations using inline
// C++ and C++ helper functions.

// Reshape(Constant(x)) = x'
def ReshapeConstant :
  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
def FoldConstantReshapeOptPattern : Pat<
  (ReshapeOp:$res (ConstantOp $arg)),
  (ConstantOp (ReshapeConstant $arg, $res))>;

//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite with Constraints
//===----------------------------------------------------------------------===//

// DRR allows for constraint checking when the transformation is conditional
// on operand properties.

// Reshape(x) = x, where input and output shapes are identical
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
  (ReshapeOp:$res $arg), (replaceWithValue $arg),
  [(TypesAreIdentical $res, $arg)]>;

#endif // TOY_COMBINE
diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0d6d9359529bffc068520bebf4a9ea56f436a415
--- /dev/null
+++ b/mlir/examples/toy/Ch5/parser/AST.cpp
@@ -0,0 +1,234 @@
//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AST dump for the Toy language.
//
//===----------------------------------------------------------------------===//

#include "toy/AST.h"

#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"

using namespace toy;

namespace {

// RAII helper to manage increasing/decreasing the indentation as we traverse
// the AST
struct Indent {
  Indent(int &level) : level(level) { ++level; }
  ~Indent() { --level; }
  int &level;
};

/// Helper class that implements the AST tree traversal and prints the nodes
/// along the way. The only data member is the current indentation level.
+class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. 
+void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + mlir::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. 
+void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..836968e218871e35e4ab8e06a5fb0544d954a30a --- /dev/null +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -0,0 +1,188 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + 
cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine }; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. 
+ sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int dumpMLIR() { + // Register our Dialect with MLIR. + mlir::registerDialect(); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; + llvm::SourceMgr sourceMgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + if (int error = loadMLIR(sourceMgr, context, module)) + return error; + + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect with a few cleanups afterwards. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. 
+ if (enableOpt) { + optPM.addPass(mlir::createLoopFusionPass()); + optPM.addPass(mlir::createMemRefDataFlowOptPass()); + } + } + + if (mlir::failed(pm.run(*module))) + return 4; + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + case Action::DumpMLIRAffine: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c342ed1d4a03fe5b316a6e2c8e90e3296f8a5d12 --- /dev/null +++ b/mlir/examples/toy/Ch6/CMakeLists.txt @@ -0,0 +1,53 @@ +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Core + Support + ) + +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include") +add_public_tablegen_target(ToyCh6CombineIncGen) + +add_toy_chapter(toyc-ch6 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/DeadFunctionEliminationPass.cpp + mlir/LowerToAffineLoops.cpp + mlir/LowerToLLVM.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + ) + +add_dependencies(toyc-ch6 ToyCh6ShapeInferenceInterfaceIncGen) +add_dependencies(toyc-ch6 ToyCh6OpsIncGen) +add_dependencies(toyc-ch6 ToyCh6CombineIncGen) +add_dependencies(toyc-ch6 MLIRCallOpInterfacesIncGen) +include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) 
+target_link_libraries(toyc-ch6 + PRIVATE + MLIRAffineOps + MLIRAffineToStandard + MLIRAnalysis + MLIRExecutionEngine + MLIRIR + MLIRLLVMIR + MLIRLoopToStandard + MLIRParser + MLIRPass + MLIRStandardOps + MLIRStandardToLLVM + MLIRTargetLLVMIR + MLIRTransforms + ) + +whole_archive_link(toyc-ch6 + MLIRAffineToStandard + MLIRAffineOps + MLIRLLVMIR + MLIRStandardOps + ) diff --git a/mlir/examples/toy/Ch6/include/CMakeLists.txt b/mlir/examples/toy/Ch6/include/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..37c89d0bae965cfc8665515de7e60ad7867a7d8b --- /dev/null +++ b/mlir/examples/toy/Ch6/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/mlir/examples/toy/Ch6/include/toy/AST.h b/mlir/examples/toy/Ch6/include/toy/AST.h new file mode 100644 index 0000000000000000000000000000000000000000..820600b5b1c900cbeedce7545bad458f096cc92e --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/AST.h @@ -0,0 +1,242 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable type with shape information. 
+struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. 
+class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. 
+class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch6/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch6/include/toy/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..aecf11fab6c94d392e6a83244cc0ed4cd3fb4b14 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +add_public_tablegen_target(ToyCh6OpsIncGen) + +set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ToyCh6ShapeInferenceInterfaceIncGen) diff --git a/mlir/examples/toy/Ch6/include/toy/Dialect.h b/mlir/examples/toy/Ch6/include/toy/Dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..5e8b91dcf4843762db80cde22ef96a0b22929840 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/Dialect.h @@ -0,0 +1,46 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with 
LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See g3doc/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" +#include "toy/ShapeInferenceInterface.h" + +namespace mlir { +namespace toy { + +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. +class ToyDialect : public mlir::Dialect { +public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } +}; + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch6/include/toy/Lexer.h b/mlir/examples/toy/Ch6/include/toy/Lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..6eff64ee5f09634041f76cbae11c18f8ca46d07c --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. 
+ Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. 
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. 
+ Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch6/include/toy/MLIRGen.h b/mlir/examples/toy/Ch6/include/toy/MLIRGen.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c8ca1201d1a2a391c0aec0d89197fbbb18efb8 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/MLIRGen.h @@ -0,0 +1,32 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td new file mode 100644 index 0000000000000000000000000000000000000000..410c5df246128bd8ddba8bc264a0ab9df9f65941 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -0,0 +1,247 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Analysis/CallInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. 
+class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'NoSideEffect' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &state, " + "DenseElementsAttr value", [{ + build(builder, state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<"Builder *builder, OperationState &state, double value"> + ]; + + // Invoke a static verify method to verify this constant operation. 
+ let verifier = [{ return ::verify(*this); }]; +} + +def AddOp : Toy_Op<"add", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def CastOp : Toy_Op<"cast", + [DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + // Set the folder bit so that we can fold redundant cast operations. + let hasFolder = 1; +} + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. 
+ }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "StringRef callee, ArrayRef arguments"> + ]; +} + +def MulOp : Toy_Op<"mul", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); +} + +def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. 
+ let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // Allow building a ReturnOp with no return operand. + let builders = [OpBuilder< + "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value input"> + ]; + + // Invoke a static verify method to verify this transpose operation. 
+ let verifier = [{ return ::verify(*this); }]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch6/include/toy/Parser.h b/mlir/examples/toy/Ch6/include/toy/Parser.h new file mode 100644 index 0000000000000000000000000000000000000000..4557ea26859de3d0a6b71448f4bef030167c3e71 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/Parser.h @@ -0,0 +1,485 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. 
+ std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. 
+ if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. 
+ return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. 
+ int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. 
+ /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. 
+ if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError("}", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. 
+ /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch6/include/toy/Passes.h b/mlir/examples/toy/Ch6/include/toy/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..33c2021c8db298671d41987e10de508507065f15 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/Passes.h @@ -0,0 +1,36 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PASSES_H +#define MLIR_TUTORIAL_TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createDeadFunctionEliminationPass(); +std::unique_ptr createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr createLowerToAffinePass(); + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr createLowerToLLVMPass(); + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..da0fb66018ee4df1882d26f074ecd49a24ddcea9 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000000000000000000000000000000000000..1b38ada1622862057ad2c18eabe147b875e18cf2 --- /dev/null +++ b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. 
+ }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ee34547860cd98c27c21da874ad794a6d0c99d5 --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp @@ -0,0 +1,59 @@ +//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Module level pass performing dead function +// elimination. This is required as a post-processing step after function +// inlining. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Passes.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace { +/// This is a simple function DCE pass that deletes all non-main functions after +/// inlining. +/// TODO(riverriddle) This is only necessary because MLIR currently does not +/// have generic DCE support for functions. 
+class DeadFunctionEliminationPass + : public mlir::ModulePass { +public: + void runOnModule() override { + mlir::ModuleOp module = getModule(); + mlir::SymbolTable moduleSymTable(module); + + // Eliminate non-main functions. + auto mainFn = moduleSymTable.lookup("main"); + for (mlir::FuncOp func : + llvm::make_early_inc_range(module.getOps())) { + if (func != mainFn) + func.erase(); + } + } +}; +} // end anonymous namespace + +/// Create a pass that eliminates inlined functions in toy. +std::unique_ptr mlir::toy::createDeadFunctionEliminationPass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a9ded0c3d38ae810d6dd114f4c3a0d85df65b60 --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -0,0 +1,261 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. 
+struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. 
This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. +static mlir::LogicalResult verify(ConstantOp op) { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = + op.getResult()->getType().dyn_cast(); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = op.value().getType().cast(); + if (attrType.getRank() != resultType.getRank()) { + return op.emitOpError( + "return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. 
+ for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op.emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp + +void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// CastOp + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } + +//===----------------------------------------------------------------------===// +// GenericCallOp + +void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return getAttrOfType("callee"); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. 
+Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } + +//===----------------------------------------------------------------------===// +// MulOp + +void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. 
+ if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = getOperand()->getType().cast(); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +static mlir::LogicalResult verify(TransposeOp op) { + auto inputType = op.getOperand()->getType().dyn_cast(); + auto resultType = op.getType().dyn_cast(); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return op.emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d6e76de069ce235033287496a0ed556789fcf4a --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,309 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops and standard operations. This lowering expects that all calls +// have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given TensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(TensorType type) { + assert(type.hasRank() && "expected only ranked shapes"); + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc.getOperation()->getBlock(); + alloc.getOperation()->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. 
It takes as input a rewriter, an array of memRefOperands corresponding +/// to the operands of the input operation, and the set of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; + +static void lowerOpToLoops(Operation *op, ArrayRef operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create an empty affine loop for each of the dimensions within the shape. + SmallVector loopIvs; + for (auto dim : tensorType.getShape()) { + auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); + loop.getBody()->clear(); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body and update the rewriter insertion point to the + // beginning of the loop. + rewriter.setInsertionPointToStart(loop.getBody()); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to the processing function with the rewriter, the memref + // operands, and the loop induction variables. This function will return the + // value to store at the current index. + Value valueToStore = processIteration(rewriter, operands, loopIvs); + rewriter.create(loc, valueToStore, alloc, + llvm::makeArrayRef(loopIvs)); + + // Replace this operation with the generated alloc. 
+ rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the BinaryOp. This + // allows for using the nice named accessors that are generated by the + // ODS. + typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the inner + // loop. + auto loadedLhs = + rewriter.create(loc, binaryAdaptor.lhs(), loopIvs); + auto loadedRhs = + rewriter.create(loc, binaryAdaptor.rhs(), loopIvs); + + // Create the binary operation performed on the loaded values. 
+ return rewriter.create(loc, loadedLhs, loadedRhs); + }); + return matchSuccess(); + } +}; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.value(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = op.getType().cast(); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back(rewriter.create(loc, i)); + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.getValues().begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. 
+ if (dimension == valueShape.size()) { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::makeArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return matchFailure(); + + // We lower "toy.return" directly to "std.return". 
+ rewriter.replaceOpWithNewOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the TransposeOp. + // This allows for using the nice named accessors that are generated + // by the ODS. + toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.input(); + + // Transpose the elements by generating a load from the reverse + // indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return rewriter.create(loc, input, reverseIvs); + }); + return matchSuccess(); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass : public FunctionPass { + void runOnFunction() final; +}; +} // end anonymous namespace. + +void ToyToAffineLoweringPass::runOnFunction() { + auto function = getFunction(); + + // We only lower the main function as we expect that all other functions have + // been inlined. 
+ if (function.getName() != "main") + return; + + // Verify that the given main has no inputs and results. + if (function.getNumArguments() || function.getType().getNumResults()) { + function.emitError("expected 'main' to have 0 inputs and 0 results"); + return signalPassFailure(); + } + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine` and `Standard` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that we don't + // want to lower, `toy.print`, as `legal`. + target.addIllegalDialect(); + target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); +} + +/// Create a pass for lowering operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). 
+std::unique_ptr mlir::toy::createLowerToAffinePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f1a6ae8bbee6c850f6c1e26e6c595b34a19b5ab --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -0,0 +1,204 @@ +//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops and standard operations. This lowering expects that all calls +// have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToLLVM RewritePatterns +//===----------------------------------------------------------------------===// + +namespace { +/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual +/// 
elements of the array. +class PrintOpLowering : public ConversionPattern { +public: + explicit PrintOpLowering(MLIRContext *context) + : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto memRefType = (*op->operand_type_begin()).cast(); + auto memRefShape = memRefType.getShape(); + auto loc = op->getLoc(); + auto *llvmDialect = + op->getContext()->getRegisteredDialect(); + assert(llvmDialect && "expected llvm dialect to be registered"); + + ModuleOp parentModule = op->getParentOfType(); + + // Get a symbol reference to the printf function, inserting it if necessary. + auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); + Value formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, + llvmDialect); + Value newLineCst = getOrCreateGlobalString( + loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); + + // Create a loop for each of the dimensions within the shape. + SmallVector loopIvs; + for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + loop.getBody()->clear(); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body. + rewriter.setInsertionPointToStart(loop.getBody()); + + // Insert a newline after each of the inner dimensions of the shape. + if (i != e - 1) + rewriter.create(loc, printfRef, rewriter.getIntegerType(32), + newLineCst); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to printf for the current element of the loop. 
+ auto printOp = cast(op); + auto elementLoad = rewriter.create(loc, printOp.input(), loopIvs); + rewriter.create(loc, printfRef, rewriter.getIntegerType(32), + ArrayRef({formatSpecifierCst, elementLoad})); + + // Notify the rewriter that this operation has been removed. + rewriter.eraseOp(op); + return matchSuccess(); + } + +private: + /// Return a symbol reference to the printf function, inserting it into the + /// module if necessary. + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { + auto *context = module.getContext(); + if (module.lookupSymbol("printf")) + return SymbolRefAttr::get("printf", context); + + // Create a function declaration for printf, the signature is: + // * `i32 (i8*, ...)` + auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, + /*isVarArg=*/true); + + // Insert the printf function into the body of the parent module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), "printf", llvmFnType); + return SymbolRefAttr::get("printf", context); + } + + /// Return a value representing an access into a global string with the given + /// name, creating the string if necessary. + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { + // Create the global at the entry of the module. 
+ LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(name))) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + auto type = LLVM::LLVMType::getArrayTy( + LLVM::LLVMType::getInt8Ty(llvmDialect), value.size()); + global = builder.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value)); + } + + // Get the pointer to the first character in the global string. + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create( + loc, LLVM::LLVMType::getInt64Ty(llvmDialect), + builder.getIntegerAttr(builder.getIndexType(), 0)); + return builder.create( + loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, + ArrayRef({cst0, cst0})); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ToyToLLVMLoweringPass +//===----------------------------------------------------------------------===// + +namespace { +struct ToyToLLVMLoweringPass : public ModulePass { + void runOnModule() final; +}; +} // end anonymous namespace + +void ToyToLLVMLoweringPass::runOnModule() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. For this lowering, we are only targeting + // the LLVM dialect. + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + + // During this lowering, we will also be lowering the MemRef types, that are + // currently being operated on, to a representation in LLVM. To perform this + // conversion we use a TypeConverter as part of the lowering. This converter + // details how one type maps to another. This is necessary now that we will be + // doing more complicated lowerings, involving loop region arguments. + LLVMTypeConverter typeConverter(&getContext()); + + // Now that the conversion target has been defined, we need to provide the + // patterns used for lowering. 
At this point of the compilation process, we + // have a combination of `toy`, `affine`, and `std` operations. Luckily, there + // already exists a set of patterns to transform `affine` and `std` + // dialects. These patterns lower in multiple stages, relying on transitive + // lowerings. Transitive lowering, or A->B->C lowering, is when multiple + // patterns must be applied to fully transform an illegal operation into a + // set of legal ones. + OwningRewritePatternList patterns; + populateAffineToStdConversionPatterns(patterns, &getContext()); + populateLoopToStdConversionPatterns(patterns, &getContext()); + populateStdToLLVMConversionPatterns(typeConverter, patterns); + + // The only remaining operation to lower from the `toy` dialect, is the + // PrintOp. + patterns.insert(&getContext()); + + // We want to completely lower to LLVM, so we use a `FullConversion`. This + // ensures that only legal operations will remain after the conversion. + auto module = getModule(); + if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + signalPassFailure(); +} + +/// Create a pass for lowering the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr mlir::toy::createLowerToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9c960c79f47254d31e72037431a0f9d3a614276 --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -0,0 +1,452 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. 
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule.push_back(func); + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(Location loc) { + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. 
+ // Arguments type are uniformly unranked tensors. + llvm::SmallVector arg_types(proto.getArgs().size(), + getType(VarType{})); + auto func_type = builder.getFunctionType(arg_types, llvm::None); + return mlir::FuncOp::create(location, proto.getName(), func_type); + } + + /// Emit a new function and add it to the MLIR module. + mlir::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + mlir::FuncOp function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + auto &entryBlock = *function.addEntryBlock(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(name_value)->getName(), + std::get<1>(name_value)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. 
+ // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + getType(VarType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. 
The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. 
+ /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. 
Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. 
+ mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. 
+ if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..517a1f075306485003e099ed805a23f77cb49147 --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp @@ -0,0 +1,104 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a FunctionPass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +class ShapeInferencePass : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto f = getFunction(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. 
+ llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa(); + }); + } +}; +} // end anonymous namespace + +/// Create a Shape Inference pass. 
+std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82c247c1be2d4da5ac4419f9267381379f8b365c --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -0,0 +1,74 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "toy/Dialect.h" +#include +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // end anonymous namespace + +/// Fold simple cast operations that return the same type as the input. +OpFoldResult CastOp::fold(ArrayRef operands) { + return mlir::impl::foldCastOp(*this); +} + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. 
The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + + // If the input is defined by another Transpose, bingo! + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.td b/mlir/examples/toy/Ch6/mlir/ToyCombine.td new file mode 100644 index 0000000000000000000000000000000000000000..e6e33e84d7e8f3e13aea9840f3690029de025d94 --- /dev/null +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.td @@ -0,0 +1,62 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(($1->getType()).cast())">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. 
+ +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : ConstraintgetType() == $1->getType()">>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d6d9359529bffc068520bebf4a9ea56f436a415 --- /dev/null +++ b/mlir/examples/toy/Ch6/parser/AST.cpp @@ -0,0 +1,234 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "mlir/ADT/TypeSwitch.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. 
+class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + mlir::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. 
+void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + mlir::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. 
+void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e5b2afb7c65c5b6fcf9b10817c320b9bfdc11b2 --- /dev/null +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -0,0 +1,274 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} 
+static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { + None, + DumpAST, + DumpMLIR, + DumpMLIRAffine, + DumpMLIRLLVM, + DumpLLVMIR, + RunJIT +}; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering")), + cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", + "output the MLIR dump after llvm lowering")), + cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run it by invoking the main function"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. 
+ llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect with a few cleanups afterwards. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. 
+ if (enableOpt) { + optPM.addPass(mlir::createLoopFusionPass()); + optPM.addPass(mlir::createMemRefDataFlowOptPass()); + } + } + + if (isLoweringToLLVM) { + // Finish lowering the toy IR to the LLVM dialect. + pm.addPass(mlir::toy::createLowerToLLVMPass()); + } + + if (mlir::failed(pm.run(*module))) + return 4; + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int dumpLLVMIR(mlir::ModuleOp module) { + auto llvmModule = mlir::translateModuleToLLVMIR(module); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. 
+ auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. + auto invocationResult = engine->invoke("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + if (emitAction == Action::DumpAST) + return dumpAST(); + + // If we aren't dumping the AST, then we are compiling with/to MLIR. + + // Register our Dialect with MLIR. + mlir::registerDialect(); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; + if (int error = loadAndProcessMLIR(context, module)) + return error; + + // If we aren't exporting to non-mlir, then we are done. + bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; + if (isOutputingMLIR) { + module->dump(); + return 0; + } + + // Check to see if we are compiling to LLVM IR. + if (emitAction == Action::DumpLLVMIR) + return dumpLLVMIR(*module); + + // Otherwise, we must be running the jit. 
+ if (emitAction == Action::RunJIT) + return runJit(*module); + + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + return -1; +} diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5956d7f4d9b7193371cad01b56e9fde00ecef716 --- /dev/null +++ b/mlir/examples/toy/Ch7/CMakeLists.txt @@ -0,0 +1,53 @@ +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Core + Support + ) + +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include") +add_public_tablegen_target(ToyCh7CombineIncGen) + +add_toy_chapter(toyc-ch7 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/DeadFunctionEliminationPass.cpp + mlir/LowerToAffineLoops.cpp + mlir/LowerToLLVM.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + ) + +add_dependencies(toyc-ch7 ToyCh7ShapeInferenceInterfaceIncGen) +add_dependencies(toyc-ch7 ToyCh7OpsIncGen) +add_dependencies(toyc-ch7 ToyCh7CombineIncGen) +add_dependencies(toyc-ch7 MLIRCallOpInterfacesIncGen) +include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +target_link_libraries(toyc-ch7 + PRIVATE + MLIRAffineOps + MLIRAffineToStandard + MLIRAnalysis + MLIRExecutionEngine + MLIRIR + MLIRLLVMIR + MLIRLoopToStandard + MLIRParser + MLIRPass + MLIRStandardOps + MLIRStandardToLLVM + MLIRTargetLLVMIR + MLIRTransforms + ) + +whole_archive_link(toyc-ch7 + MLIRAffineToStandard + MLIRAffineOps + MLIRLLVMIR + MLIRStandardOps + ) diff --git a/mlir/examples/toy/Ch7/include/CMakeLists.txt b/mlir/examples/toy/Ch7/include/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..37c89d0bae965cfc8665515de7e60ad7867a7d8b --- /dev/null +++ b/mlir/examples/toy/Ch7/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git 
a/mlir/examples/toy/Ch7/include/toy/AST.h b/mlir/examples/toy/Ch7/include/toy/AST.h new file mode 100644 index 0000000000000000000000000000000000000000..3d3ae89dbeb2d03adaaa98c40d112d1e1a1285ab --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/AST.h @@ -0,0 +1,308 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable type with either name or shape information. +struct VarType { + std::string name; + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_StructLiteral, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". 
+class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for a literal struct value. +class StructLiteralExprAST : public ExprAST { + std::vector> values; + +public: + StructLiteralExprAST(Location loc, + std::vector> values) + : ExprAST(Expr_StructLiteral, loc), values(std::move(values)) {} + + llvm::ArrayRef> getValues() { return values; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { + return c->getKind() == Expr_StructLiteral; + } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. 
+class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal = nullptr) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::None; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a top level record in a module. +class RecordAST { +public: + enum RecordASTKind { + Record_Function, + Record_Struct, + }; + + RecordAST(RecordASTKind kind) : kind(kind) {} + virtual ~RecordAST() = default; + + RecordASTKind getKind() const { return kind; } + +private: + const RecordASTKind kind; +}; + +/// This class represents a function definition itself. 
+class FunctionAST : public RecordAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : RecordAST(Record_Function), proto(std::move(proto)), + body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } + + /// LLVM style RTTI + static bool classof(const RecordAST *R) { + return R->getKind() == Record_Function; + } +}; + +/// This class represents a struct definition. +class StructAST : public RecordAST { + Location location; + std::string name; + std::vector> variables; + +public: + StructAST(Location location, const std::string &name, + std::vector> variables) + : RecordAST(Record_Struct), location(location), name(name), + variables(std::move(variables)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getVariables() { + return variables; + } + + /// LLVM style RTTI + static bool classof(const RecordAST *R) { + return R->getKind() == Record_Struct; + } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector> records; + +public: + ModuleAST(std::vector> records) + : records(std::move(records)) {} + + auto begin() -> decltype(records.begin()) { return records.begin(); } + auto end() -> decltype(records.end()) { return records.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa30bd2e8e03eae897f5b7110703bb811125662e --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +mlir_tablegen(Ops.cpp.inc -gen-op-defs 
"-I${CMAKE_CURRENT_SOURCE_DIR}/..") +add_public_tablegen_target(ToyCh7OpsIncGen) + +set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ToyCh7ShapeInferenceInterfaceIncGen) diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..77481b1884fab8ce35fa3864d3c0fcb0303ffc51 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -0,0 +1,100 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See g3doc/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" +#include "toy/ShapeInferenceInterface.h" + +namespace mlir { +namespace toy { +namespace detail { +struct StructTypeStorage; +} // end namespace detail + +/// This is the definition of the Toy dialect. A dialect inherits from +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. +class ToyDialect : public mlir::Dialect { +public: + explicit ToyDialect(mlir::MLIRContext *ctx); + + /// A hook used to materialize constant values with the given type. 
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; + + /// Parse an instance of a type registered to the toy dialect. + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + + /// Print an instance of a type registered to the toy dialect. + void printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const override; + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } +}; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +/// Create a local enumeration with all of the types that are defined by Toy. +namespace ToyTypes { +enum Types { + Struct = mlir::Type::FIRST_TOY_TYPE, +}; +} // end namespace ToyTypes + +/// This class defines the Toy struct type. It represents a collection of +/// element types. All derived types in MLIR must inherit from the CRTP class +/// 'Type::TypeBase'. It takes as template parameters the concrete type +/// (StructType), the base class to use (Type), and the storage class +/// (StructTypeStorage). +class StructType : public mlir::Type::TypeBase { +public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + + /// This static method is used to support type inquiry through isa, cast, + /// and dyn_cast. 
+ static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; } + + /// Create an instance of a `StructType` with the given element types. There + /// *must* be atleast one element type. + static StructType get(llvm::ArrayRef elementTypes); + + /// Returns the element types of this struct type. + llvm::ArrayRef getElementTypes(); + + /// Returns the number of element type held by this struct. + size_t getNumElementTypes() { return getElementTypes().size(); } +}; +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/Lexer.h b/mlir/examples/toy/Ch7/include/toy/Lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..b41b82f2a0a7064351b99ce6575bddce9b5dc96e --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Lexer.h @@ -0,0 +1,235 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. 
+enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + tok_struct = -5, + + // primary + tok_identifier = -6, + tok_number = -7, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. 
+ int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "struct") + return tok_struct; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9] ([0-9.])* + if (isdigit(lastChar)) { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. 
+ do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. 
+ llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/MLIRGen.h b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c8ca1201d1a2a391c0aec0d89197fbbb18efb8 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h @@ -0,0 +1,32 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. 
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td new file mode 100644 index 0000000000000000000000000000000000000000..15395c6da4e61058afd36da4a7c860593a8f4ca1 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -0,0 +1,300 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Analysis/CallInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +// Provide a definition for the Toy StructType for use in ODS. This allows for +// using StructType in a similar way to Tensor or MemRef. +def Toy_StructType : + Type()">, "Toy struct type">; + +// Provide a definition of the types that are used within the Toy dialect. 
+def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'NoSideEffect' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", + [NoSideEffect, DeclareOpInterfaceMethods]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &state, " + "DenseElementsAttr value", [{ + build(builder, state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<"Builder *builder, OperationState &state, double value"> + ]; + + // Invoke a static verify method to verify this constant operation. 
+ let verifier = [{ return ::verify(*this); }]; + + // Set the folder bit so that we can implement constant folders. + let hasFolder = 1; +} + +def AddOp : Toy_Op<"add", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def CastOp : Toy_Op<"cast", + [DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + // Set the folder bit so that we can fold redundant cast operations. + let hasFolder = 1; +} + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. 
For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType or + // StructType. + let results = (outs Toy_Type); + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "StringRef callee, ArrayRef arguments"> + ]; +} + +def MulOp : Toy_Op<"mul", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> + ]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); +} + +def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. 
For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // Allow building a ReturnOp with no return operand. + let builders = [OpBuilder< + "Builder *b, OperationState &state", [{ build(b, state, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> { + let summary = "struct access"; + let description = [{ + Access the Nth element of a value returning a struct type. + }]; + + let arguments = (ins Toy_StructType:$input, I64Attr:$index); + let results = (outs Toy_Type); + + // Allow building a StructAccessOp with just a struct value and an index. 
+ let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value input, size_t index"> + ]; + + let verifier = [{ return ::verify(*this); }]; + + // Set the folder bit so that we can fold constant accesses. + let hasFolder = 1; +} + +def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> { + let summary = "struct constant"; + let description = [{ + Constant operation turns a literal struct value into an SSA value. The data + is attached to the operation as an attribute. The struct constant is encoded + as an array of other constant values. For example: + + ```mlir + %0 = "toy.struct_constant"() { + value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>] + } : () -> !toy.struct> + ``` + }]; + + let hasFolder = 1; + let arguments = (ins ArrayAttr:$value); + let results = (outs Toy_StructType); + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<"Builder *b, OperationState &state, Value input"> + ]; + + // Invoke a static verify method to verify this transpose operation. + let verifier = [{ return ::verify(*this); }]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch7/include/toy/Parser.h b/mlir/examples/toy/Ch7/include/toy/Parser.h new file mode 100644 index 0000000000000000000000000000000000000000..d2659e04dacb028c092264bdd8db91acf91ac518 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Parser.h @@ -0,0 +1,678 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions and structs one at a time and accumulate in this vector. 
+ std::vector> records; + while (true) { + std::unique_ptr record; + switch (lexer.getCurToken()) { + case tok_eof: + break; + case tok_def: + record = parseDefinition(); + break; + case tok_struct: + record = parseStruct(); + break; + default: + return parseError("'def' or 'struct'", + "when parsing top level module records"); + } + if (!record) + break; + records.push_back(std::move(record)); + } + + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(records)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. 
+ } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// Parse a literal struct expression. + /// structLiteral ::= { (structLiteral | tensorLiteral)+ } + std::unique_ptr parseStructLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('{')); + + // Hold the list of values. + std::vector> values; + do { + // We can have either another nested array or a number literal. 
+ if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; + } else if (lexer.getCurToken() == tok_number) { + values.push_back(parseNumberExpr()); + if (!values.back()) + return nullptr; + } else { + if (lexer.getCurToken() != '{') + return parseError("{, [, or number", + "in struct literal expression"); + values.push_back(parseStructLiteralExpr()); + } + + // End of this list on '}' + if (lexer.getCurToken() == '}') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("} or ,", "in struct literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", + "to fill struct literal expression"); + lexer.getNextToken(); // eat } + + return std::make_unique(std::move(loc), + std::move(values)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// Parse a call expression. 
+ std::unique_ptr parseCallExpr(llvm::StringRef name, + const Location &loc) { + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + return parseCallExpr(name, loc); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case '{': + return parseStructLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. 
+ /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse either a variable declaration or a call expression. 
+ std::unique_ptr parseDeclarationOrCallExpr() { + auto loc = lexer.getLastLocation(); + std::string id = lexer.getId(); + lexer.consume(tok_identifier); + + // Check for a call expression. + if (lexer.getCurToken() == '(') + return parseCallExpr(id, loc); + + // Otherwise, this is a variable declaration. + return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc); + } + + /// Parse a typed variable declaration. + std::unique_ptr + parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer, + const Location &loc) { + // Parse the variable name. + if (lexer.getCurToken() != tok_identifier) + return parseError("name", "in variable declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + // Parse the initializer. + std::unique_ptr expr; + if (requiresInitializer) { + if (lexer.getCurToken() != '=') + return parseError("initializer", + "in variable declaration"); + lexer.consume(Token('=')); + expr = parseExpression(); + } + + VarType type; + type.name = typeName; + return std::make_unique(loc, std::move(id), std::move(type), + std::move(expr)); + } + + /// Parse a variable declaration, for either a tensor value or a struct value, + /// with an optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? + /// decl ::= identifier identifier (= expr)? + std::unique_ptr parseDeclaration(bool requiresInitializer) { + // Check to see if this is a 'var' declaration. + if (lexer.getCurToken() == tok_var) + return parseVarDeclaration(requiresInitializer); + + // Parse the type name. + if (lexer.getCurToken() != tok_identifier) + return parseError("type name", "in variable declaration"); + auto loc = lexer.getLastLocation(); + std::string typeName = lexer.getId(); + lexer.getNextToken(); // eat id + + // Parse the rest of the declaration. 
+ return parseTypedDeclaration(typeName, requiresInitializer, loc); + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? + std::unique_ptr + parseVarDeclaration(bool requiresInitializer) { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + if (!type) + type = std::make_unique(); + + std::unique_ptr expr; + if (requiresInitializer) { + lexer.consume(Token('=')); + expr = parseExpression(); + } + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. 
+ while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_identifier) { + // Variable declaration or call + auto expr = parseDeclarationOrCallExpr(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } else if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(/*requiresInitializer=*/true); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + VarType type; + std::string name; + + // Parse either the name of the variable, or its type. 
+ std::string nameOrType = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + + // If the next token is an identifier, we just parsed the type. + if (lexer.getCurToken() == tok_identifier) { + type.name = std::move(nameOrType); + + // Parse the name. + name = lexer.getId(); + lexer.consume(tok_identifier); + } else { + // Otherwise, we just parsed the name. + name = std::move(nameOrType); + } + + args.push_back( + std::make_unique(std::move(loc), name, type)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError("}", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Parse a struct definition, we expect a struct initiated with the + /// `struct` keyword, followed by a block containing a list of variable + /// declarations. 
+ /// + /// definition ::= `struct` identifier `{` decl+ `}` + std::unique_ptr parseStruct() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_struct); + if (lexer.getCurToken() != tok_identifier) + return parseError("name", "in struct definition"); + std::string name = lexer.getId(); + lexer.consume(tok_identifier); + + // Parse: '{' + if (lexer.getCurToken() != '{') + return parseError("{", "in struct definition"); + lexer.consume(Token('{')); + + // Parse: decl+ + std::vector> decls; + do { + auto decl = parseDeclaration(/*requiresInitializer=*/false); + if (!decl) + return nullptr; + decls.push_back(std::move(decl)); + + if (lexer.getCurToken() != ';') + return parseError(";", + "after variable in struct definition"); + lexer.consume(Token(';')); + } while (lexer.getCurToken() != '}'); + + // Parse: '}' + lexer.consume(Token('}')); + return std::make_unique(loc, name, std::move(decls)); + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + case '.': + return 60; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. 
+ template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch7/include/toy/Passes.h b/mlir/examples/toy/Ch7/include/toy/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..33c2021c8db298671d41987e10de508507065f15 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/Passes.h @@ -0,0 +1,36 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PASSES_H +#define MLIR_TUTORIAL_TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createDeadFunctionEliminationPass(); +std::unique_ptr createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr createLowerToAffinePass(); + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. 
+std::unique_ptr createLowerToLLVMPass(); + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..da0fb66018ee4df1882d26f074ecd49a24ddcea9 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000000000000000000000000000000000000..1b38ada1622862057ad2c18eabe147b875e18cf2 --- /dev/null +++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ee34547860cd98c27c21da874ad794a6d0c99d5 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp @@ -0,0 +1,59 @@ +//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Module level pass performing dead function +// elimination. This is required as a post-processing step after function +// inlining. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Passes.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace { +/// This is a simple function DCE pass that deletes all non-main functions after +/// inlining. +/// TODO(riverriddle) This is only necessary because MLIR currently does not +/// have generic DCE support for functions. +class DeadFunctionEliminationPass + : public mlir::ModulePass { +public: + void runOnModule() override { + mlir::ModuleOp module = getModule(); + mlir::SymbolTable moduleSymTable(module); + + // Eliminate non-main functions. + auto mainFn = moduleSymTable.lookup("main"); + for (mlir::FuncOp func : + llvm::make_early_inc_range(module.getOps())) { + if (func != mainFn) + func.erase(); + } + } +}; +} // end anonymous namespace + +/// Create a pass that eliminates inlined functions in toy. +std::unique_ptr mlir::toy::createDeadFunctionEliminationPass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7e37f61a4739d51a1bd9e806a5c01d6f88ffd3c5 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -0,0 +1,474 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. 
+ assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); + addTypes(); +} + +mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + if (type.isa()) + return builder.create(loc, type, + value.cast()); + return builder.create(loc, type, + value.cast()); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. 
+void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verify that the given attribute value is valid for the given type. +static mlir::LogicalResult verifyConstantForType(mlir::Type type, + mlir::Attribute opaqueValue, + mlir::Operation *op) { + if (type.isa()) { + // Check that the value is a elements attribute. + auto attrValue = opaqueValue.dyn_cast(); + if (!attrValue) + return op->emitError("constant of TensorType must be initialized by " + "a DenseFPElementsAttr, got ") + << opaqueValue; + + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = type.dyn_cast(); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the + // constant result type. + auto attrType = attrValue.getType().cast(); + if (attrType.getRank() != resultType.getRank()) { + return op->emitOpError("return type must match the one of the attached " + "value attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op->emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); + } + auto resultType = type.cast(); + llvm::ArrayRef resultElementTypes = resultType.getElementTypes(); + + // Verify that the initializer is an Array. 
+ auto attrValue = opaqueValue.dyn_cast(); + if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) + return op->emitError("constant of StructType must be initialized by an " + "ArrayAttr with the same number of elements, got ") + << opaqueValue; + + // Check that each of the elements are valid. + llvm::ArrayRef attrElementValues = attrValue.getValue(); + for (const auto &it : llvm::zip(resultElementTypes, attrElementValues)) + if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) + return mlir::failure(); + return mlir::success(); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. +static mlir::LogicalResult verify(ConstantOp op) { + return verifyConstantForType(op.getResult()->getType(), op.value(), op); +} + +static mlir::LogicalResult verify(StructConstantOp op) { + return verifyConstantForType(op.getResult()->getType(), op.value(), op); +} + +/// Infer the output shape of the ConstantOp, this is required by the shape +/// inference interface. +void ConstantOp::inferShapes() { getResult()->setType(value().getType()); } + +//===----------------------------------------------------------------------===// +// AddOp + +void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// CastOp + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. 
+void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } + +//===----------------------------------------------------------------------===// +// GenericCallOp + +void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return getAttrOfType("callee"); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } + +//===----------------------------------------------------------------------===// +// MulOp + +void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. 
+ const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +//===----------------------------------------------------------------------===// +// StructAccessOp + +void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state, + mlir::Value input, size_t index) { + // Extract the result type from the input type. + StructType structTy = input->getType().cast(); + assert(index < structTy.getNumElementTypes()); + mlir::Type resultType = structTy.getElementTypes()[index]; + + // Call into the auto-generated build method. 
+ build(b, state, resultType, input, b->getI64IntegerAttr(index)); +} + +static mlir::LogicalResult verify(StructAccessOp op) { + StructType structTy = op.input()->getType().cast(); + size_t index = op.index().getZExtValue(); + if (index >= structTy.getNumElementTypes()) + return op.emitOpError() + << "index should be within the range of the input struct type"; + mlir::Type resultType = op.getResult()->getType(); + if (resultType != structTy.getElementTypes()[index]) + return op.emitOpError() << "must have the same result type as the struct " + "element referred to by the index"; + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TransposeOp + +void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = getOperand()->getType().cast(); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +static mlir::LogicalResult verify(TransposeOp op) { + auto inputType = op.getOperand()->getType().dyn_cast(); + auto resultType = op.getType().dyn_cast(); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return op.emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { +namespace detail { +/// This class represents the internal storage of the Toy `StructType`. 
+struct StructTypeStorage : public mlir::TypeStorage { + /// The `KeyTy` is a required type that provides an interface for the storage + /// instance. This type will be used when uniquing an instance of the type + /// storage. For our struct type, we will unique each instance structurally on + /// the elements that it contains. + using KeyTy = llvm::ArrayRef; + + /// A constructor for the type storage instance. + StructTypeStorage(llvm::ArrayRef elementTypes) + : elementTypes(elementTypes) {} + + /// Define the comparison function for the key type with the current storage + /// instance. This is used when constructing a new instance to ensure that we + /// haven't already uniqued an instance of the given key. + bool operator==(const KeyTy &key) const { return key == elementTypes; } + + /// Define a hash function for the key type. This is used when uniquing + /// instances of the storage, see the `StructType::get` method. + /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type + /// have hash functions available, so we could just omit this entirely. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Define a construction function for the key type from a set of parameters. + /// These parameters will be provided when constructing the storage instance + /// itself. + /// Note: This method isn't necessary because KeyTy can be directly + /// constructed with the given parameters. + static KeyTy getKey(llvm::ArrayRef elementTypes) { + return KeyTy(elementTypes); + } + + /// Define a construction method for creating a new instance of this storage. + /// This method takes an instance of a storage allocator, and an instance of a + /// `KeyTy`. The given allocator must be used for *all* necessary dynamic + /// allocations used to create the type storage and its internal. 
+ static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + // Copy the elements from the provided `KeyTy` into the allocator. + llvm::ArrayRef elementTypes = allocator.copyInto(key); + + // Allocate the storage instance and construct it. + return new (allocator.allocate()) + StructTypeStorage(elementTypes); + } + + /// The following field contains the element types of the struct. + llvm::ArrayRef elementTypes; +}; +} // end namespace detail +} // end namespace toy +} // end namespace mlir + +/// Create an instance of a `StructType` with the given element types. There +/// *must* be at least one element type. +StructType StructType::get(llvm::ArrayRef elementTypes) { + assert(!elementTypes.empty() && "expected at least 1 element type"); + + // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance + // of this type. The first two parameters are the context to unique in and the + // kind of the type. The parameters after the type kind are forwarded to the + // storage instance. + mlir::MLIRContext *ctx = elementTypes.front().getContext(); + return Base::get(ctx, ToyTypes::Struct, elementTypes); +} + +/// Returns the element types of this struct type. +llvm::ArrayRef StructType::getElementTypes() { + // 'getImpl' returns a pointer to the internal storage instance. + return getImpl()->elementTypes; +} + +/// Parse an instance of a type registered to the toy dialect. +mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { + // Parse a struct type in the following form: + // struct-type ::= `struct` `<` type (`,` type)* `>` + + // NOTE: All MLIR parser function return a ParseResult. This is a + // specialization of LogicalResult that auto-converts to a `true` boolean + // value on failure to allow for chaining, but may be used with explicit + // `mlir::failed/mlir::succeeded` as desired. 
+ + // Parse: `struct` `<` + if (parser.parseKeyword("struct") || parser.parseLess()) + return Type(); + + // Parse the element types of the struct. + SmallVector elementTypes; + do { + // Parse the current element type. + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + mlir::Type elementType; + if (parser.parseType(elementType)) + return nullptr; + + // Check that the type is either a TensorType or another StructType. + if (!elementType.isa() && + !elementType.isa()) { + parser.emitError(typeLoc, "element type for a struct must either " + "be a TensorType or a StructType, got: ") + << elementType; + return Type(); + } + elementTypes.push_back(elementType); + + // Parse the optional: `,` + } while (succeeded(parser.parseOptionalComma())); + + // Parse: `>` + if (parser.parseGreater()) + return Type(); + return StructType::get(elementTypes); +} + +/// Print an instance of a type registered to the toy dialect. +void ToyDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // Currently the only toy type is a struct type. + StructType structType = type.cast(); + + // Print the struct type according to the parser format. 
+ printer << "struct<"; + mlir::interleaveComma(structType.getElementTypes(), printer); + printer << '>'; +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d6e76de069ce235033287496a0ed556789fcf4a --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,309 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops and standard operations. This lowering expects that all calls +// have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given TensorType into the corresponding MemRefType. 
+static MemRefType convertTensorToMemRef(TensorType type) { + assert(type.hasRank() && "expected only ranked shapes"); + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc.getOperation()->getBlock(); + alloc.getOperation()->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. It takes as input a rewriter, an array of memRefOperands corresponding +/// to the operands of the input operation, and the set of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; + +static void lowerOpToLoops(Operation *op, ArrayRef operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create an empty affine loop for each of the dimensions within the shape. 
+ SmallVector loopIvs; + for (auto dim : tensorType.getShape()) { + auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); + loop.getBody()->clear(); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body and update the rewriter insertion point to the + // beginning of the loop. + rewriter.setInsertionPointToStart(loop.getBody()); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to the processing function with the rewriter, the memref + // operands, and the loop induction variables. This function will return the + // value to store at the current index. + Value valueToStore = processIteration(rewriter, operands, loopIvs); + rewriter.create(loc, valueToStore, alloc, + llvm::makeArrayRef(loopIvs)); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the BinaryOp. This + // allows for using the nice named accessors that are generated by the + // ODS. + typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the inner + // loop. 
+ auto loadedLhs = + rewriter.create(loc, binaryAdaptor.lhs(), loopIvs); + auto loadedRhs = + rewriter.create(loc, binaryAdaptor.rhs(), loopIvs); + + // Create the binary operation performed on the loaded values. + return rewriter.create(loc, loadedLhs, loadedRhs); + }); + return matchSuccess(); + } +}; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.value(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = op.getType().cast(); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back(rewriter.create(loc, i)); + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. 
+ SmallVector indices; + auto valueIt = constantValue.getValues().begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::makeArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return matchFailure(); + + // We lower "toy.return" directly to "std.return". 
+ rewriter.replaceOpWithNewOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the TransposeOp. + // This allows for using the nice named accessors that are generated + // by the ODS. + toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.input(); + + // Transpose the elements by generating a load from the reverse + // indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return rewriter.create(loc, input, reverseIvs); + }); + return matchSuccess(); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass : public FunctionPass { + void runOnFunction() final; +}; +} // end anonymous namespace. + +void ToyToAffineLoweringPass::runOnFunction() { + auto function = getFunction(); + + // We only lower the main function as we expect that all other functions have + // been inlined. 
+ if (function.getName() != "main") + return; + + // Verify that the given main has no inputs and results. + if (function.getNumArguments() || function.getType().getNumResults()) { + function.emitError("expected 'main' to have 0 inputs and 0 results"); + return signalPassFailure(); + } + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine` and `Standard` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. + target.addIllegalDialect(); + target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); +} + +/// Create a pass for lowering operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). 
+std::unique_ptr mlir::toy::createLowerToAffinePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f1a6ae8bbee6c850f6c1e26e6c595b34a19b5ab --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -0,0 +1,204 @@ +//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops and standard operations. This lowering expects that all calls +// have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToLLVM RewritePatterns +//===----------------------------------------------------------------------===// + +namespace { +/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual +/// 
elements of the array. +class PrintOpLowering : public ConversionPattern { +public: + explicit PrintOpLowering(MLIRContext *context) + : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto memRefType = (*op->operand_type_begin()).cast(); + auto memRefShape = memRefType.getShape(); + auto loc = op->getLoc(); + auto *llvmDialect = + op->getContext()->getRegisteredDialect(); + assert(llvmDialect && "expected llvm dialect to be registered"); + + ModuleOp parentModule = op->getParentOfType(); + + // Get a symbol reference to the printf function, inserting it if necessary. + auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); + Value formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, + llvmDialect); + Value newLineCst = getOrCreateGlobalString( + loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); + + // Create a loop for each of the dimensions within the shape. + SmallVector loopIvs; + for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + loop.getBody()->clear(); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body. + rewriter.setInsertionPointToStart(loop.getBody()); + + // Insert a newline after each of the inner dimensions of the shape. + if (i != e - 1) + rewriter.create(loc, printfRef, rewriter.getIntegerType(32), + newLineCst); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to printf for the current element of the loop. 
+ auto printOp = cast(op); + auto elementLoad = rewriter.create(loc, printOp.input(), loopIvs); + rewriter.create(loc, printfRef, rewriter.getIntegerType(32), + ArrayRef({formatSpecifierCst, elementLoad})); + + // Notify the rewriter that this operation has been removed. + rewriter.eraseOp(op); + return matchSuccess(); + } + +private: + /// Return a symbol reference to the printf function, inserting it into the + /// module if necessary. + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { + auto *context = module.getContext(); + if (module.lookupSymbol("printf")) + return SymbolRefAttr::get("printf", context); + + // Create a function declaration for printf, the signature is: + // * `i32 (i8*, ...)` + auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, + /*isVarArg=*/true); + + // Insert the printf function into the body of the parent module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), "printf", llvmFnType); + return SymbolRefAttr::get("printf", context); + } + + /// Return a value representing an access into a global string with the given + /// name, creating the string if necessary. + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { + // Create the global at the entry of the module. 
+ LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(name))) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + auto type = LLVM::LLVMType::getArrayTy( + LLVM::LLVMType::getInt8Ty(llvmDialect), value.size()); + global = builder.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value)); + } + + // Get the pointer to the first character in the global string. + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create( + loc, LLVM::LLVMType::getInt64Ty(llvmDialect), + builder.getIntegerAttr(builder.getIndexType(), 0)); + return builder.create( + loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, + ArrayRef({cst0, cst0})); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ToyToLLVMLoweringPass +//===----------------------------------------------------------------------===// + +namespace { +struct ToyToLLVMLoweringPass : public ModulePass { + void runOnModule() final; +}; +} // end anonymous namespace + +void ToyToLLVMLoweringPass::runOnModule() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. For this lowering, we are only targeting + // the LLVM dialect. + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + + // During this lowering, we will also be lowering the MemRef types, that are + // currently being operated on, to a representation in LLVM. Do perform this + // conversion we use a TypeConverter as part of the lowering. This converter + // details how one type maps to another. This is necessary now that we will be + // doing more complicated lowerings, involving loop region arguments. + LLVMTypeConverter typeConverter(&getContext()); + + // Now that the conversion target has been defined, we need to provide the + // patterns used for lowering. 
At this point of the compilation process, we + // have a combination of `toy`, `affine`, and `std` operations. Luckily, there + // are already exists a set of patterns to transform `affine` and `std` + // dialects. These patterns lowering in multiple stages, relying on transitive + // lowerings. Transitive lowering, or A->B->C lowering, is when multiple + // patterns must be applied to fully transform an illegal operation into a + // set of legal ones. + OwningRewritePatternList patterns; + populateAffineToStdConversionPatterns(patterns, &getContext()); + populateLoopToStdConversionPatterns(patterns, &getContext()); + populateStdToLLVMConversionPatterns(typeConverter, patterns); + + // The only remaining operation to lower from the `toy` dialect, is the + // PrintOp. + patterns.insert(&getContext()); + + // We want to completely lower to LLVM, so we use a `FullConversion`. This + // ensures that only legal operations will remain after the conversion. + auto module = getModule(); + if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + signalPassFailure(); +} + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr mlir::toy::createLowerToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d543f69bdc6a0ab2aca2364ee2f91c8cbe2140e --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -0,0 +1,674 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::makeArrayRef; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. 
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (auto &record : moduleAST) { + if (FunctionAST *funcAST = llvm::dyn_cast(record.get())) { + auto func = mlirGen(*funcAST); + if (!func) + return nullptr; + + theModule.push_back(func); + functionMap.insert({func.getName(), func}); + } else if (StructAST *str = llvm::dyn_cast(record.get())) { + if (failed(mlirGen(*str))) + return nullptr; + } else { + llvm_unreachable("unknown record type"); + } + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable> + symbolTable; + using SymbolTableScopeT = + llvm::ScopedHashTableScope>; + + /// A mapping for the functions that have been code generated to MLIR. + llvm::StringMap functionMap; + + /// A mapping for named struct types to the underlying MLIR type and the + /// original AST node. + llvm::StringMap> structMap; + + /// Helper conversion for a Toy AST location to an MLIR location. 
+ mlir::Location loc(Location loc) { + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { + if (symbolTable.count(var.getName())) + return mlir::failure(); + symbolTable.insert(var.getName(), {value, &var}); + return mlir::success(); + } + + /// Create an MLIR type for the given struct. + mlir::LogicalResult mlirGen(StructAST &str) { + if (structMap.count(str.getName())) + return emitError(loc(str.loc())) << "error: struct type with name `" + << str.getName() << "' already exists"; + + auto variables = str.getVariables(); + std::vector elementTypes; + elementTypes.reserve(variables.size()); + for (auto &variable : variables) { + if (variable->getInitVal()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + if (!variable->getType().shape.empty()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + + mlir::Type type = getType(variable->getType(), variable->loc()); + if (!type) + return mlir::failure(); + elementTypes.push_back(type); + } + + structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. 
+ llvm::SmallVector argTypes; + argTypes.reserve(proto.getArgs().size()); + for (auto &arg : proto.getArgs()) { + mlir::Type type = getType(arg->getType(), arg->loc()); + if (!type) + return nullptr; + argTypes.push_back(type); + } + auto func_type = builder.getFunctionType(argTypes, llvm::None); + return mlir::FuncOp::create(location, proto.getName(), func_type); + } + + /// Emit a new function and add it to the MLIR module. + mlir::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + SymbolTableScopeT var_scope(symbolTable); + + // Create an MLIR function for the given prototype. + mlir::FuncOp function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + auto &entryBlock = *function.addEntryBlock(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(*std::get<0>(name_value), std::get<1>(name_value)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. 
+ // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + *returnOp.operand_type_begin())); + } + + return function; + } + + /// Return the struct type that is the result of the given expression, or null + /// if it cannot be inferred. + StructAST *getStructFor(ExprAST *expr) { + llvm::StringRef structName; + if (auto *decl = llvm::dyn_cast(expr)) { + auto varIt = symbolTable.lookup(decl->getName()); + if (!varIt.first) + return nullptr; + structName = varIt.second->getType().name; + } else if (auto *access = llvm::dyn_cast(expr)) { + if (access->getOp() != '.') + return nullptr; + // The name being accessed should be in the RHS. + auto *name = llvm::dyn_cast(access->getRHS()); + if (!name) + return nullptr; + StructAST *parentStruct = getStructFor(access->getLHS()); + if (!parentStruct) + return nullptr; + + // Get the element within the struct corresponding to the name. + VarDeclExprAST *decl = nullptr; + for (auto &var : parentStruct->getVariables()) { + if (var->getName() == name->getName()) { + decl = var.get(); + break; + } + } + if (!decl) + return nullptr; + structName = decl->getType().name; + } + if (structName.empty()) + return nullptr; + + // If the struct name was valid, check for an entry in the struct map. + auto structIt = structMap.find(structName); + if (structIt == structMap.end()) + return nullptr; + return structIt->second.second; + } + + /// Return the numeric member index of the given struct access expression. + llvm::Optional getMemberIndex(BinaryExprAST &accessOp) { + assert(accessOp.getOp() == '.' 
&& "expected access operation"); + + // Lookup the struct node for the LHS. + StructAST *structAST = getStructFor(accessOp.getLHS()); + if (!structAST) + return llvm::None; + + // Get the name from the RHS. + VariableExprAST *name = llvm::dyn_cast(accessOp.getRHS()); + if (!name) + return llvm::None; + + auto structVars = structAST->getVariables(); + auto it = llvm::find_if(structVars, [&](auto &var) { + return var->getName() == name->getName(); + }); + if (it == structVars.end()) + return llvm::None; + return it - structVars.begin(); + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + auto location = loc(binop.loc()); + + // If this is an access operation, handle it immediately. + if (binop.getOp() == '.') { + llvm::Optional accessIndex = getMemberIndex(binop); + if (!accessIndex) { + emitError(location, "invalid access into struct expression"); + return nullptr; + } + return builder.create(location, lhs, *accessIndex); + } + + // Otherwise, this is a normal binary op. + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. 
+ switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName()).first) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); + } + + /// Emit a constant for a literal/constant array. It will be emitted as a + /// flattened array of data in an Attribute attached to a `toy.constant` + /// operation. See documentation on [Attributes](LangRef.md#attributes) for + /// more details. Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. 
+ /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) { + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + return mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); + } + mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) { + // The type of this attribute is tensor of 64-bit floating-point with no + // shape. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get({}, elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + return mlir::DenseElementsAttr::get(dataType, + llvm::makeArrayRef(lit.getValue())); + } + /// Emit a constant for a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. This function returns the generated constant, along with the + /// corresponding struct type. 
+ std::pair + getConstantAttr(StructLiteralExprAST &lit) { + std::vector attrElements; + std::vector typeElements; + + for (auto &var : lit.getValues()) { + if (auto *number = llvm::dyn_cast(var.get())) { + attrElements.push_back(getConstantAttr(*number)); + typeElements.push_back(getType(llvm::None)); + } else if (auto *lit = llvm::dyn_cast(var.get())) { + attrElements.push_back(getConstantAttr(*lit)); + typeElements.push_back(getType(llvm::None)); + } else { + auto *structLit = llvm::cast(var.get()); + auto attrTypePair = getConstantAttr(*structLit); + attrElements.push_back(attrTypePair.first); + typeElements.push_back(attrTypePair.second); + } + } + mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements); + mlir::Type dataType = StructType::get(typeElements); + return std::make_pair(dataAttr, dataType); + } + + /// Emit an array literal. + mlir::Value mlirGen(LiteralExprAST &lit) { + mlir::Type type = getType(lit.getDims()); + mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Emit a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. + mlir::Value mlirGen(StructLiteralExprAST &lit) { + mlir::ArrayAttr dataAttr; + mlir::Type dataType; + std::tie(dataAttr, dataType) = getConstantAttr(lit); + + // Build the MLIR op `toy.struct_constant`. This invokes the + // `StructConstantOp::build` method. + return builder.create(loc(lit.loc()), dataType, dataAttr); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. 
+ /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + auto calledFuncIt = functionMap.find(callee); + if (calledFuncIt == functionMap.end()) { + emitError(location) << "no defined function found for '" << callee << "'"; + return nullptr; + } + mlir::FuncOp calledFunc = calledFuncIt->second; + return builder.create( + location, calledFunc.getType().getResult(0), + builder.getSymbolRefAttr(callee), operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). 
+ mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_StructLiteral: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // Handle the case where we are initializing a struct value. + VarType varType = vardecl.getType(); + if (!varType.name.empty()) { + // Check that the initializer type is the same as the variable + // declaration. 
+ mlir::Type type = getType(varType, vardecl.loc()); + if (!type) + return nullptr; + if (type != value->getType()) { + emitError(loc(vardecl.loc())) + << "struct type of initializer is different than the variable " + "declaration. Got " + << value->getType() << ", but expected " << type; + return nullptr; + } + + // Otherwise, we have the initializer value, but in case the variable was + // declared with specific shape, we emit a "reshape" operation. It will + // get optimized out later as needed. + } else if (!varType.shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(varType.shape), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl, value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + SymbolTableScopeT var_scope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. 
+ return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above for non-struct types). + mlir::Type getType(const VarType &type, const Location &location) { + if (!type.name.empty()) { + auto it = structMap.find(type.name); + if (it == structMap.end()) { + emitError(loc(location)) + << "error: unknown struct type '" << type.name << "'"; + return nullptr; + } + return it->second.first; + } + + return getType(type.shape); + } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..517a1f075306485003e099ed805a23f77cb49147 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp @@ -0,0 +1,104 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a FunctionPass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +class ShapeInferencePass : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto f = getFunction(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. + llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). 
+ auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa(); + }); + } +}; +} // end anonymous namespace + +/// Create a Shape Inference pass. +std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c688a53d86f7db586916127e922b345379d836fe --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -0,0 +1,92 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "toy/Dialect.h" +#include +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // end anonymous namespace + +/// Fold simple cast operations that return the same type as the input. +OpFoldResult CastOp::fold(ArrayRef operands) { + return mlir::impl::foldCastOp(*this); +} + +/// Fold constants. +OpFoldResult ConstantOp::fold(ArrayRef operands) { return value(); } + +/// Fold struct constants. +OpFoldResult StructConstantOp::fold(ArrayRef operands) { + return value(); +} + +/// Fold simple struct access operations that access into a constant. +OpFoldResult StructAccessOp::fold(ArrayRef operands) { + auto structAttr = operands.front().dyn_cast_or_null(); + if (!structAttr) + return nullptr; + + size_t elementIndex = index().getZExtValue(); + return structAttr.getValue()[elementIndex]; +} + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. 
+ mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + + // If the input is defined by another Transpose, bingo! + if (!transposeInputOp) + return matchFailure(); + + // Use the rewriter to perform the replacement. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.td b/mlir/examples/toy/Ch7/mlir/ToyCombine.td new file mode 100644 index 0000000000000000000000000000000000000000..e6e33e84d7e8f3e13aea9840f3690029de025d94 --- /dev/null +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.td @@ -0,0 +1,62 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(($1->getType()).cast())">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. 
+ +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : ConstraintgetType() == $1->getType()">>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/mlir/examples/toy/Ch7/parser/AST.cpp b/mlir/examples/toy/Ch7/parser/AST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..669bc9dbec21514b8bf3a7971028600b33e8d41a --- /dev/null +++ b/mlir/examples/toy/Ch7/parser/AST.cpp @@ -0,0 +1,271 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "mlir/ADT/TypeSwitch.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. 
+class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(StructLiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + void dump(StructAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + mlir::TypeSwitch(expr) + .Case([&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + if (auto *initVal = varDecl->getInitVal()) + dump(initVal); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + mlir::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a struct literal. 
+void ASTDumper::dump(StructLiteralExprAST *node) { + INDENT(); + llvm::errs() << "Struct Literal: "; + for (auto &value : node->getValues()) + dump(value.get()); + indent(); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().hasValue()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + if (!type.name.empty()) + llvm::errs() << type.name; + else + mlir::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. 
+void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + mlir::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a struct. +void ASTDumper::dump(StructAST *node) { + INDENT(); + llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n"; + + { + INDENT(); + llvm::errs() << "Variables: [\n"; + for (auto &variable : node->getVariables()) + dump(variable.get()); + indent(); + llvm::errs() << "]\n"; + } +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &record : *node) { + if (FunctionAST *function = llvm::dyn_cast(record.get())) + dump(function); + else if (StructAST *str = llvm::dyn_cast(record.get())) + dump(str); + else + llvm::errs() << "getKind() << ">\n"; + } +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c6afab594e1fc74a3c75df91a32c8ba0e45d4543 --- /dev/null +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -0,0 +1,275 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Analysis/Verifier.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { + None, + DumpAST, + DumpMLIR, + DumpMLIRAffine, + DumpMLIRLLVM, + DumpLLVMIR, + RunJIT +}; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering")), + 
cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", + "output the MLIR dump after llvm lowering")), + cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run it by invoking the main function"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + + // Parse the input mlir. 
+ llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect with a few cleanups afterwards. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. + if (enableOpt) { + optPM.addPass(mlir::createLoopFusionPass()); + optPM.addPass(mlir::createMemRefDataFlowOptPass()); + } + } + + if (isLoweringToLLVM) { + // Finish lowering the toy IR to the LLVM dialect. 
+ pm.addPass(mlir::toy::createLowerToLLVMPass()); + } + + if (mlir::failed(pm.run(*module))) + return 4; + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int dumpLLVMIR(mlir::ModuleOp module) { + auto llvmModule = mlir::translateModuleToLLVMIR(module); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. 
+ auto invocationResult = engine->invoke("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + if (emitAction == Action::DumpAST) + return dumpAST(); + + // If we aren't dumping the AST, then we are compiling with/to MLIR. + + // Register our Dialect with MLIR. + mlir::registerDialect(); + + mlir::MLIRContext context; + mlir::OwningModuleRef module; + if (int error = loadAndProcessMLIR(context, module)) + return error; + + // If we aren't exporting to non-mlir, then we are done. + bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; + if (isOutputingMLIR) { + module->dump(); + return 0; + } + + // Check to see if we are compiling to LLVM IR. + if (emitAction == Action::DumpLLVMIR) + return dumpLLVMIR(*module); + + // Otherwise, we must be running the jit. + if (emitAction == Action::RunJIT) + return runJit(*module); + + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + return -1; +} diff --git a/mlir/examples/toy/README.md b/mlir/examples/toy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..53912c83abfb228f97a2ad4bac0a93bcd23176a3 --- /dev/null +++ b/mlir/examples/toy/README.md @@ -0,0 +1,7 @@ +# Toy Tutorial + +This contains sample code to support the tutorial on using MLIR for +building a compiler for a simple Toy language. + +See [g3doc/Tutorials/Toy](../../g3doc/Tutorials/Toy) at the root of +the repository for more informations. 
diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h new file mode 100644 index 0000000000000000000000000000000000000000..5e3e2087f8bf40f749470f9efb46d0e2a6ddd209 --- /dev/null +++ b/mlir/include/mlir-c/Core.h @@ -0,0 +1,109 @@ +/*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\ +|* *| +|* Part of the MLIR Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* This header declares the C interface to MLIR. *| +|* *| +\*===----------------------------------------------------------------------===*/ +#ifndef MLIR_C_CORE_H +#define MLIR_C_CORE_H + +#ifdef __cplusplus +#include +extern "C" { +#else +#include +#endif + +/// Opaque MLIR types. +/// Opaque C type for mlir::MLIRContext*. +typedef void *mlir_context_t; +/// Opaque C type for mlir::Type. +typedef const void *mlir_type_t; +/// Opaque C type for mlir::FuncOp. +typedef void *mlir_func_t; +/// Opaque C type for mlir::Attribute. +typedef const void *mlir_attr_t; + +/// Simple C lists for non-owning mlir Opaque C types. +/// Recommended usage is construction from the `data()` and `size()` of a scoped +/// owning SmallVectorImpl<...> and passing to one of the C functions declared +/// later in this file. +/// Once the function returns and the proper EDSC has been constructed, +/// resources are freed by exiting the scope. +typedef struct { + int64_t *values; + uint64_t n; +} int64_list_t; + +typedef struct { + mlir_type_t *types; + uint64_t n; +} mlir_type_list_t; + +typedef struct { + const char *name; + mlir_attr_t value; +} mlir_named_attr_t; + +typedef struct { + mlir_named_attr_t *list; + uint64_t n; +} mlir_named_attr_list_t; + +/// Minimal C API for exposing EDSCs to Swift, Python and other languages. 
+ +/// Returns an `mlir::MemRefType` of the element type `elemType` and shape +/// `sizes`. +mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType, + int64_list_t sizes); + +/// Returns an `mlir::FunctionType` of the element type `elemType` and shape +/// `sizes`. +mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs, + mlir_type_list_t outputs); + +/// Returns an `mlir::IndexType`. +mlir_type_t makeIndexType(mlir_context_t context); + +/// Returns an `mlir::IntegerAttr` of the specified type that contains the given +/// value. +mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value); + +/// Returns an `mlir::BoolAttr` with the given value. +mlir_attr_t makeBoolAttr(mlir_context_t context, bool value); + +/// Returns an `mlir::FloatAttr` with the given value. +mlir_attr_t makeFloatAttr(mlir_context_t context, float value); + +/// Returns an `mlir::StringAttr` with the given value. +mlir_attr_t makeStringAttr(mlir_context_t context, const char *value); + +/// Parses an MLIR type from the string `type` in the given context. Returns a +/// NULL type on error. If non-NULL, `charsRead` will contain the number of +/// characters that were processed by the parser. +mlir_type_t mlirParseType(const char *type, mlir_context_t context, + uint64_t *charsRead); + +/// Returns the arity of `function`. +unsigned getFunctionArity(mlir_func_t function); + +/// Returns the rank of the `function` argument at position `pos`. +/// If the argument is of MemRefType, this returns the rank of the MemRef. +/// Otherwise returns `0`. +/// TODO(ntv): support more than MemRefType and scalar Type. +unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos); + +/// Returns an opaque mlir::Type of the `function` argument at position `pos`. 
+mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos); + +#ifdef __cplusplus +} // end extern "C" +#endif + +#endif // MLIR_C_CORE_H diff --git a/mlir/include/mlir/ADT/TypeSwitch.h b/mlir/include/mlir/ADT/TypeSwitch.h new file mode 100644 index 0000000000000000000000000000000000000000..2dbc611f557e096157c847e8e37910ed9d2b9638 --- /dev/null +++ b/mlir/include/mlir/ADT/TypeSwitch.h @@ -0,0 +1,176 @@ +//===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the TypeSwitch template, which mimics a switch() +// statement whose cases are type names. +// +//===-----------------------------------------------------------------------===/ + +#ifndef MLIR_SUPPORT_TYPESWITCH_H +#define MLIR_SUPPORT_TYPESWITCH_H + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Optional.h" + +namespace mlir { +namespace detail { + +template class TypeSwitchBase { +public: + TypeSwitchBase(const T &value) : value(value) {} + TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {} + ~TypeSwitchBase() = default; + + /// TypeSwitchBase is not copyable. + TypeSwitchBase(const TypeSwitchBase &) = delete; + void operator=(const TypeSwitchBase &) = delete; + void operator=(TypeSwitchBase &&other) = delete; + + /// Invoke a case on the derived class with multiple case types. + template + DerivedT &Case(CallableT &&caseFn) { + DerivedT &derived = static_cast(*this); + return derived.template Case(caseFn) + .template Case(caseFn); + } + + /// Invoke a case on the derived class, inferring the type of the Case from + /// the first input of the given callable. 
+ /// Note: This inference rules for this overload are very simple: strip + /// pointers and references. + template DerivedT &Case(CallableT &&caseFn) { + using Traits = FunctionTraits>; + using CaseT = std::remove_cv_t>>>; + + DerivedT &derived = static_cast(*this); + return derived.template Case(std::forward(caseFn)); + } + +protected: + /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type + /// `CastT`. + template + using has_dyn_cast_t = + decltype(std::declval().template dyn_cast()); + + /// Attempt to dyn_cast the given `value` to `CastT`. This overload is + /// selected if `value` already has a suitable dyn_cast method. + template + static auto castValue( + ValueT value, + typename std::enable_if_t< + is_detected::value> * = nullptr) { + return value.template dyn_cast(); + } + + /// Attempt to dyn_cast the given `value` to `CastT`. This overload is + /// selected if llvm::dyn_cast should be used. + template + static auto castValue( + ValueT value, + typename std::enable_if_t< + !is_detected::value> * = nullptr) { + return dyn_cast(value); + } + + /// The root value we are switching on. + const T value; +}; +} // end namespace detail + +/// This class implements a switch-like dispatch statement for a value of 'T' +/// using dyn_cast functionality. Each `Case` takes a callable to be invoked +/// if the root value isa, the callable is invoked with the result of +/// dyn_cast() as a parameter. +/// +/// Example: +/// Operation *op = ...; +/// LogicalResult result = TypeSwitch(op) +/// .Case([](ConstantOp op) { ... }) +/// .Default([](Operation *op) { ... }); +/// +template +class TypeSwitch : public detail::TypeSwitchBase, T> { +public: + using BaseT = detail::TypeSwitchBase, T>; + using BaseT::BaseT; + using BaseT::Case; + TypeSwitch(TypeSwitch &&other) = default; + + /// Add a case on the given type. + template + TypeSwitch &Case(CallableT &&caseFn) { + if (result) + return *this; + + // Check to see if CaseT applies to 'value'. 
+ if (auto caseValue = BaseT::template castValue(this->value)) + result = caseFn(caseValue); + return *this; + } + + /// As a default, invoke the given callable within the root value. + template + LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) { + if (result) + return std::move(*result); + return defaultFn(this->value); + } + + LLVM_NODISCARD + operator ResultT() { + assert(result && "Fell off the end of a type-switch"); + return std::move(*result); + } + +private: + /// The pointer to the result of this switch statement, once known, + /// null before that. + Optional result; +}; + +/// Specialization of TypeSwitch for void returning callables. +template +class TypeSwitch + : public detail::TypeSwitchBase, T> { +public: + using BaseT = detail::TypeSwitchBase, T>; + using BaseT::BaseT; + using BaseT::Case; + TypeSwitch(TypeSwitch &&other) = default; + + /// Add a case on the given type. + template + TypeSwitch &Case(CallableT &&caseFn) { + if (foundMatch) + return *this; + + // Check to see if any of the types apply to 'value'. + if (auto caseValue = BaseT::template castValue(this->value)) { + caseFn(caseValue); + foundMatch = true; + } + return *this; + } + + /// As a default, invoke the given callable within the root value. + template void Default(CallableT &&defaultFn) { + if (!foundMatch) + defaultFn(this->value); + } + +private: + /// A flag detailing if we have already found a match. + bool foundMatch = false; +}; +} // end namespace mlir + +#endif // MLIR_SUPPORT_TYPESWITCH_H diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h new file mode 100644 index 0000000000000000000000000000000000000000..d0bcb932c04c78215893a301f8eb5a9e1d5da161 --- /dev/null +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -0,0 +1,131 @@ +//===- AffineAnalysis.h - analyses for affine structures --------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes for methods that perform analysis +// involving affine structures (AffineExprStorage, AffineMap, IntegerSet, etc.) +// and other IR structures that in turn use these. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H +#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H + +#include "mlir/IR/Value.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +class AffineApplyOp; +class AffineForOp; +class AffineValueMap; +class FlatAffineConstraints; +class Operation; + +/// Returns in `affineApplyOps`, the sequence of those AffineApplyOp +/// Operations that are reachable via a search starting from `operands` and +/// ending at those operands that are not the result of an AffineApplyOp. +void getReachableAffineApplyOps(ArrayRef operands, + SmallVectorImpl &affineApplyOps); + +/// Builds a system of constraints with dimensional identifiers corresponding to +/// the loop IVs of the forOps appearing in that order. Bounds of the loop are +/// used to add appropriate inequalities. Any symbols founds in the bound +/// operands are added as symbols in the system. Returns failure for the yet +/// unimplemented cases. +// TODO(bondhugula): handle non-unit strides. +LogicalResult getIndexSet(MutableArrayRef forOps, + FlatAffineConstraints *domain); + +/// Encapsulates a memref load or store access information. +struct MemRefAccess { + Value memref; + Operation *opInst; + SmallVector indices; + + /// Constructs a MemRefAccess from a load or store operation. + // TODO(b/119949820): add accessors to standard op's load, store, DMA op's to + // return MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess. 
+ explicit MemRefAccess(Operation *opInst); + + // Returns the rank of the memref associated with this access. + unsigned getRank() const; + // Returns true if this access is of a store op. + bool isStore() const; + + /// Populates 'accessMap' with composition of AffineApplyOps reachable from + /// 'indices'. + void getAccessMap(AffineValueMap *accessMap) const; + + /// Equal if both affine accesses can be proved to be equivalent at compile + /// time (considering the memrefs, their respective affine access maps and + /// operands). The equality of access functions + operands is checked by + /// subtracting fully composed value maps, and then simplifying the difference + /// using the expression flattener. + /// TODO: this does not account for aliasing of memrefs. + bool operator==(const MemRefAccess &rhs) const; + bool operator!=(const MemRefAccess &rhs) const { return !(*this == rhs); } +}; + +// DependenceComponent contains state about the direction of a dependence as an +// interval [lb, ub] for an AffineForOp. +// Distance vectors components are represented by the interval [lb, ub] with +// lb == ub. +// Direction vectors components are represented by the interval [lb, ub] with +// lb < ub. Note that ub/lb == None means unbounded. +struct DependenceComponent { + // The AffineForOp Operation associated with this dependence component. + Operation *op; + // The lower bound of the dependence distance. + Optional lb; + // The upper bound of the dependence distance (inclusive). + Optional ub; + DependenceComponent() : lb(llvm::None), ub(llvm::None) {} +}; + +/// Checks whether two accesses to the same memref access the same element. +/// Each access is specified using the MemRefAccess structure, which contains +/// the operation, indices and memref associated with the access. Returns +/// 'NoDependence' if it can be determined conclusively that the accesses do not +/// access the same memref element. 
If 'allowRAR' is true, will consider +/// read-after-read dependences (typically used by applications trying to +/// optimize input reuse). +// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into +// a single struct. +// TODO(andydavis) Make 'dependenceConstraints' optional arg. +struct DependenceResult { + enum ResultEnum { + HasDependence, // A dependence exists between 'srcAccess' and 'dstAccess'. + NoDependence, // No dependence exists between 'srcAccess' and 'dstAccess'. + Failure, // Dependence check failed due to unsupported cases. + } value; + DependenceResult(ResultEnum v) : value(v) {} +}; + +DependenceResult checkMemrefAccessDependence( + const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, + SmallVector *dependenceComponents, + bool allowRAR = false); + +/// Utility function that returns true if the provided DependenceResult +/// corresponds to a dependence result. +inline bool hasDependence(DependenceResult result) { + return result.value == DependenceResult::HasDependence; +} + +/// Returns in 'depCompsVec', dependence components for dependences between all +/// load and store ops in loop nest rooted at 'forOp', at loop depths in range +/// [1, maxLoopDepth]. +void getDependenceComponents( + AffineForOp forOp, unsigned maxLoopDepth, + std::vector> *depCompsVec); + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h new file mode 100644 index 0000000000000000000000000000000000000000..47e0ddab5479761a0eab5624e876ef3fd293db8a --- /dev/null +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -0,0 +1,815 @@ +//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Structures for affine/polyhedral analysis of ML functions. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_AFFINE_STRUCTURES_H +#define MLIR_ANALYSIS_AFFINE_STRUCTURES_H + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { + +class AffineApplyOp; +class AffineBound; +class AffineCondition; +class AffineMap; +class AffineForOp; +class IntegerSet; +class MLIRContext; +class Value; +class HyperRectangularSet; +class MemRefType; + +/// A mutable affine map. Its affine expressions are however unique. +struct MutableAffineMap { +public: + MutableAffineMap() {} + MutableAffineMap(AffineMap map); + + ArrayRef getResults() const { return results; } + AffineExpr getResult(unsigned idx) const { return results[idx]; } + void setResult(unsigned idx, AffineExpr result) { results[idx] = result; } + unsigned getNumResults() const { return results.size(); } + unsigned getNumDims() const { return numDims; } + void setNumDims(unsigned d) { numDims = d; } + unsigned getNumSymbols() const { return numSymbols; } + void setNumSymbols(unsigned d) { numSymbols = d; } + MLIRContext *getContext() const { return context; } + + /// Returns true if the idx'th result expression is a multiple of factor. + bool isMultipleOf(unsigned idx, int64_t factor) const; + + /// Resets this MutableAffineMap with 'map'. + void reset(AffineMap map); + + /// Simplify the (result) expressions in this map using analysis (used by + //-simplify-affine-expr pass). + void simplify(); + /// Get the AffineMap corresponding to this MutableAffineMap. Note that an + /// AffineMap will be uniqued and stored in context, while a mutable one + /// isn't. 
+ AffineMap getAffineMap() const; + +private: + // Same meaning as AffineMap's fields. + SmallVector results; + unsigned numDims; + unsigned numSymbols; + /// A pointer to the IR's context to store all newly created + /// AffineExprStorage's. + MLIRContext *context; +}; + +/// A mutable integer set. Its affine expressions are however unique. +struct MutableIntegerSet { +public: + MutableIntegerSet(IntegerSet set, MLIRContext *context); + + /// Create a universal set (no constraints). + MutableIntegerSet(unsigned numDims, unsigned numSymbols, + MLIRContext *context); + + unsigned getNumDims() const { return numDims; } + unsigned getNumSymbols() const { return numSymbols; } + unsigned getNumConstraints() const { return constraints.size(); } + + void clear() { + constraints.clear(); + eqFlags.clear(); + } + +private: + unsigned numDims; + unsigned numSymbols; + + SmallVector constraints; + SmallVector eqFlags; +}; + +/// An AffineValueMap is an affine map plus its ML value operands and +/// results for analysis purposes. The structure is still a tree form that is +/// same as that of an affine map or an AffineApplyOp. However, its operands, +/// results, and its map can themselves change as a result of +/// substitutions, simplifications, and other analysis. +// An affine value map can readily be constructed from an AffineApplyOp, or an +// AffineBound of a AffineForOp. It can be further transformed, substituted +// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and +// destroyed during analysis. Only the AffineMap expressions that are pointed by +// them are unique'd. An affine value map, and the operations on it, maintain +// the invariant that operands are always positionally aligned with the +// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap. +// TODO(bondhugula): Some of these classes could go into separate files. 
+class AffineValueMap { +public: + // Creates an empty AffineValueMap (users should call 'reset' to reset map + // and operands). + AffineValueMap() {} + AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); + + explicit AffineValueMap(AffineApplyOp applyOp); + explicit AffineValueMap(AffineBound bound); + + ~AffineValueMap(); + + // Resets this AffineValueMap with 'map', 'operands', and 'results'. + void reset(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); + + /// Return the value map that is the difference of value maps 'a' and 'b', + /// represented as an affine map and its operands. The output map + operands + /// are canonicalized and simplified. + static void difference(const AffineValueMap &a, const AffineValueMap &b, + AffineValueMap *res); + + /// Return true if the idx^th result can be proved to be a multiple of + /// 'factor', false otherwise. + inline bool isMultipleOf(unsigned idx, int64_t factor) const; + + /// Return true if the idx^th result depends on 'value', false otherwise. + bool isFunctionOf(unsigned idx, Value value) const; + + /// Return true if the result at 'idx' is a constant, false + /// otherwise. + bool isConstant(unsigned idx) const; + + /// Return true if this is an identity map. + bool isIdentity() const; + + void setResult(unsigned i, AffineExpr e) { map.setResult(i, e); } + AffineExpr getResult(unsigned i) { return map.getResult(i); } + inline unsigned getNumOperands() const { return operands.size(); } + inline unsigned getNumDims() const { return map.getNumDims(); } + inline unsigned getNumSymbols() const { return map.getNumSymbols(); } + inline unsigned getNumResults() const { return map.getNumResults(); } + + Value getOperand(unsigned i) const; + ArrayRef getOperands() const; + AffineMap getAffineMap() const; + +private: + // A mutable affine map. + MutableAffineMap map; + + // TODO: make these trailing objects? + /// The SSA operands binding to the dim's and symbols of 'map'. 
+ SmallVector operands; + /// The SSA results binding to the results of 'map'. + SmallVector results; +}; + +/// An IntegerValueSet is an integer set plus its operands. +// Both, the integer set being pointed to and the operands can change during +// analysis, simplification, and transformation. +class IntegerValueSet { + /// Constructs an integer value set from an affine value map. + // This will lead to a single equality in 'set'. + explicit IntegerValueSet(const AffineValueMap &avm); + + /// Returns true if this integer set is determined to be empty. Emptiness is + /// checked by by eliminating identifiers successively (through either + /// Gaussian or Fourier-Motzkin) while using the GCD test and a trivial + /// invalid constraint check. Returns 'true' if the constraint system is found + /// to be empty; false otherwise. This method is exact for rational spaces but + /// not integer spaces - thus, if it returns true, the set is provably integer + /// empty as well, but if it returns false, it doesn't necessarily mean an + /// integer point exists in it. This method also returns false where an + /// explosion of constraints is detected - due to the super-exponential + /// worse-case complexity of Fourier-Motzkin elimination (rare for realistic + /// problem cases but possible for artificial adversarial or improperly + // constructed ones), this method returns false conservatively. + bool isEmpty() const; + + bool getNumDims() const { return set.getNumDims(); } + bool getNumSymbols() const { return set.getNumSymbols(); } + +private: + // The set pointed to may itself change unlike in IR structures like + // 'AffineCondition'. + MutableIntegerSet set; + /// The SSA operands binding to the dim's and symbols of 'set'. + SmallVector operands; +}; + +/// A flat list of affine equalities and inequalities in the form. +/// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} >= 0 +/// Equality: c_0*x_0 + c_1*x_1 + .... 
+ c_{n-1}*x_{n-1} == 0 +/// +/// FlatAffineConstraints stores coefficients in a contiguous buffer (one buffer +/// for equalities and one for inequalities). The size of each buffer is +/// numReservedCols * number of inequalities (or equalities). The reserved size +/// is numReservedCols * numReservedInequalities (or numReservedEqualities). A +/// coefficient (r, c) lives at the location numReservedCols * r + c in the +/// buffer. The extra space between getNumCols() and numReservedCols exists to +/// prevent frequent movement of data when adding columns, especially at the +/// end. +/// +/// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers, +/// symbolic identifiers, and local identifiers. The local identifiers +/// correspond to local/internal variables created when converting from +/// AffineExpr's containing mod's and div's; they are thus needed to increase +/// representational power. Each local identifier is always (by construction) a +/// floordiv of a pure add/mul affine function of dimensional, symbolic, and +/// other local identifiers, in a non-mutually recursive way. Hence, every local +/// identifier can ultimately always be recovered as an affine function of +/// dimensional and symbolic identifiers (involving floordiv's); note however +/// that some floordiv combinations are converted to mod's by AffineExpr +/// construction. +/// +class FlatAffineConstraints { +public: + enum IdKind { Dimension, Symbol, Local }; + + /// Constructs a constraint system reserving memory for the specified number + /// of constraints and identifiers.. 
+ FlatAffineConstraints(unsigned numReservedInequalities, + unsigned numReservedEqualities, + unsigned numReservedCols, unsigned numDims = 0, + unsigned numSymbols = 0, unsigned numLocals = 0, + ArrayRef> idArgs = {}) + : numReservedCols(numReservedCols), numDims(numDims), + numSymbols(numSymbols) { + assert(numReservedCols >= numDims + numSymbols + 1); + assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); + equalities.reserve(numReservedCols * numReservedEqualities); + inequalities.reserve(numReservedCols * numReservedInequalities); + numIds = numDims + numSymbols + numLocals; + ids.reserve(numReservedCols); + if (idArgs.empty()) + ids.resize(numIds, None); + else + ids.append(idArgs.begin(), idArgs.end()); + } + + /// Constructs a constraint system with the specified number of + /// dimensions and symbols. + FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, + unsigned numLocals = 0, + ArrayRef> idArgs = {}) + : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), + numSymbols(numSymbols) { + assert(numReservedCols >= numDims + numSymbols + 1); + assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); + numIds = numDims + numSymbols + numLocals; + ids.reserve(numIds); + if (idArgs.empty()) + ids.resize(numIds, None); + else + ids.append(idArgs.begin(), idArgs.end()); + } + + explicit FlatAffineConstraints(const HyperRectangularSet &set); + + /// Create a flat affine constraint system from an AffineValueMap or a list of + /// these. The constructed system will only include equalities. + // TODO(bondhugula) + explicit FlatAffineConstraints(const AffineValueMap &avm); + explicit FlatAffineConstraints(ArrayRef avmRef); + + /// Creates an affine constraint system from an IntegerSet. + explicit FlatAffineConstraints(IntegerSet set); + + /// Create an affine constraint system from an IntegerValueSet. 
+ // TODO(bondhugula) + explicit FlatAffineConstraints(const IntegerValueSet &set); + + FlatAffineConstraints(const FlatAffineConstraints &other); + + FlatAffineConstraints(ArrayRef avmRef, + IntegerSet set); + + FlatAffineConstraints(const MutableAffineMap &map); + + ~FlatAffineConstraints() {} + + // Clears any existing data and reserves memory for the specified constraints. + void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, + unsigned numReservedCols, unsigned numDims, unsigned numSymbols, + unsigned numLocals = 0, ArrayRef idArgs = {}); + + void reset(unsigned numDims = 0, unsigned numSymbols = 0, + unsigned numLocals = 0, ArrayRef idArgs = {}); + + /// Appends constraints from 'other' into this. This is equivalent to an + /// intersection with no simplification of any sort attempted. + void append(const FlatAffineConstraints &other); + + // Checks for emptiness by performing variable elimination on all identifiers, + // running the GCD test on each equality constraint, and checking for invalid + // constraints. + // Returns true if the GCD test fails for any equality, or if any invalid + // constraints are discovered on any row. Returns false otherwise. + bool isEmpty() const; + + // Runs the GCD test on all equality constraints. Returns 'true' if this test + // fails on any equality. Returns 'false' otherwise. + // This test can be used to disprove the existence of a solution. If it + // returns true, no integer solution to the equality constraints can exist. + bool isEmptyByGCDTest() const; + + // Clones this object. + std::unique_ptr clone() const; + + /// Returns the value at the specified equality row and column. 
+ inline int64_t atEq(unsigned i, unsigned j) const { + return equalities[i * numReservedCols + j]; + } + inline int64_t &atEq(unsigned i, unsigned j) { + return equalities[i * numReservedCols + j]; + } + + inline int64_t atIneq(unsigned i, unsigned j) const { + return inequalities[i * numReservedCols + j]; + } + + inline int64_t &atIneq(unsigned i, unsigned j) { + return inequalities[i * numReservedCols + j]; + } + + /// Returns the number of columns in the constraint system. + inline unsigned getNumCols() const { return numIds + 1; } + + inline unsigned getNumEqualities() const { + assert(equalities.size() % numReservedCols == 0 && + "inconsistent equality buffer size"); + return equalities.size() / numReservedCols; + } + + inline unsigned getNumInequalities() const { + assert(inequalities.size() % numReservedCols == 0 && + "inconsistent inequality buffer size"); + return inequalities.size() / numReservedCols; + } + + inline unsigned getNumReservedEqualities() const { + return equalities.capacity() / numReservedCols; + } + + inline unsigned getNumReservedInequalities() const { + return inequalities.capacity() / numReservedCols; + } + + inline ArrayRef getEquality(unsigned idx) const { + return ArrayRef(&equalities[idx * numReservedCols], getNumCols()); + } + + inline ArrayRef getInequality(unsigned idx) const { + return ArrayRef(&inequalities[idx * numReservedCols], + getNumCols()); + } + + AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); + + /// Adds constraints (lower and upper bounds) for the specified 'affine.for' + /// operation's Value using IR information stored in its bound maps. The + /// right identifier is first looked up using forOp's Value. Asserts if the + /// Value corresponding to the 'affine.for' operation isn't found in the + /// constraint system. Returns failure for the yet unimplemented/unsupported + /// cases. 
Any new identifiers that are found in the bound operands of the + /// 'affine.for' operation are added as trailing identifiers (either + /// dimensional or symbolic depending on whether the operand is a valid + /// symbol). + // TODO(bondhugula): add support for non-unit strides. + LogicalResult addAffineForOpDomain(AffineForOp forOp); + + /// Adds a lower or an upper bound for the identifier at the specified + /// position with constraints being drawn from the specified bound map and + /// operands. If `eq` is true, add a single equality equal to the bound map's + /// first result expr. + LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, + ArrayRef operands, bool eq, + bool lower = true); + + /// Computes the lower and upper bounds of the first 'num' dimensional + /// identifiers (starting at 'offset') as an affine map of the remaining + /// identifiers (dimensional and symbolic). This method is able to detect + /// identifiers as floordiv's and mod's of affine expressions of other + /// identifiers with respect to (positive) constants. Sets bound map to a + /// null AffineMap if such a bound can't be found (or yet unimplemented). + void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context, + SmallVectorImpl *lbMaps, + SmallVectorImpl *ubMaps); + + /// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper + /// bounds in 'ubMaps' to each identifier in the constraint system which has + /// a value in 'values'. Note that both lower/upper bounds share the same + /// operand list 'operands'. + /// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size'. + /// Note that both lower/upper bounds use operands from 'operands'. + LogicalResult addSliceBounds(ArrayRef values, + ArrayRef lbMaps, + ArrayRef ubMaps, + ArrayRef operands); + + // Adds an inequality (>= 0) from the coefficients specified in inEq. + void addInequality(ArrayRef inEq); + // Adds an equality from the coefficients specified in eq. 
+ void addEquality(ArrayRef eq); + + /// Adds a constant lower bound constraint for the specified identifier. + void addConstantLowerBound(unsigned pos, int64_t lb); + /// Adds a constant upper bound constraint for the specified identifier. + void addConstantUpperBound(unsigned pos, int64_t ub); + + /// Adds a new local identifier as the floordiv of an affine function of other + /// identifiers, the coefficients of which are provided in 'dividend' and with + /// respect to a positive constant 'divisor'. Two constraints are added to the + /// system to capture equivalence with the floordiv: + /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. + void addLocalFloorDiv(ArrayRef dividend, int64_t divisor); + + /// Adds a constant lower bound constraint for the specified expression. + void addConstantLowerBound(ArrayRef expr, int64_t lb); + /// Adds a constant upper bound constraint for the specified expression. + void addConstantUpperBound(ArrayRef expr, int64_t ub); + + /// Sets the identifier at the specified position to a constant. + void setIdToConstant(unsigned pos, int64_t val); + + /// Sets the identifier corresponding to the specified Value id to a + /// constant. Asserts if the 'id' is not found. + void setIdToConstant(Value id, int64_t val); + + /// Looks up the position of the identifier with the specified Value. Returns + /// true if found (false otherwise). `pos' is set to the (column) position of + /// the identifier. + bool findId(Value id, unsigned *pos) const; + + /// Returns true if an identifier with the specified Value exists, false + /// otherwise. + bool containsId(Value id) const; + + // Add identifiers of the specified kind - specified positions are relative to + // the kind of identifier. The coefficient column corresponding to the added + // identifier is initialized to zero. 'id' is the Value corresponding to the + // identifier that can optionally be provided. 
+ void addDimId(unsigned pos, Value id = nullptr); + void addSymbolId(unsigned pos, Value id = nullptr); + void addLocalId(unsigned pos); + void addId(IdKind kind, unsigned pos, Value id = nullptr); + + /// Add the specified values as a dim or symbol id depending on its nature, if + /// it already doesn't exist in the system. `id' has to be either a terminal + /// symbol or a loop IV, i.e., it cannot be the result affine.apply of any + /// symbols or loop IVs. The identifier is added to the end of the existing + /// dims or symbols. Additional information on the identifier is extracted + /// from the IR and added to the constraint system. + void addInductionVarOrTerminalSymbol(Value id); + + /// Composes the affine value map with this FlatAffineConstrains, adding the + /// results of the map as dimensions at the front [0, vMap->getNumResults()) + /// and with the dimensions set to the equalities specified by the value map. + /// Returns failure if the composition fails (when vMap is a semi-affine map). + /// The vMap's operand Value's are used to look up the right positions in + /// the FlatAffineConstraints with which to associate. The dimensional and + /// symbolic operands of vMap should match 1:1 (in the same order) with those + /// of this constraint system, but the latter could have additional trailing + /// operands. + LogicalResult composeMap(const AffineValueMap *vMap); + + /// Composes an affine map whose dimensions match one to one to the + /// dimensions of this FlatAffineConstraints. The results of the map 'other' + /// are added as the leading dimensions of this constraint system. Returns + /// failure if 'other' is a semi-affine map. + LogicalResult composeMatchingMap(AffineMap other); + + /// Projects out (aka eliminates) 'num' identifiers starting at position + /// 'pos'. The resulting constraint system is the shadow along the dimensions + /// that still exist. This method may not always be integer exact. 
+ // TODO(bondhugula): deal with integer exactness when necessary - can return a + // value to mark exactness for example. + void projectOut(unsigned pos, unsigned num); + inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + + /// Projects out the identifier that is associate with Value . + void projectOut(Value id); + + void removeId(IdKind idKind, unsigned pos); + void removeId(unsigned pos); + + void removeDim(unsigned pos); + + void removeEquality(unsigned pos); + void removeInequality(unsigned pos); + + /// Changes the partition between dimensions and symbols. Depending on the new + /// symbol count, either a chunk of trailing dimensional identifiers becomes + /// symbols, or some of the leading symbols become dimensions. + void setDimSymbolSeparation(unsigned newSymbolCount); + + /// Changes all symbol identifiers which are loop IVs to dim identifiers. + void convertLoopIVSymbolsToDims(); + + /// Sets the specified identifier to a constant and removes it. + void setAndEliminate(unsigned pos, int64_t constVal); + + /// Tries to fold the specified identifier to a constant using a trivial + /// equality detection; if successful, the constant is substituted for the + /// identifier everywhere in the constraint system and then removed from the + /// system. + LogicalResult constantFoldId(unsigned pos); + + /// This method calls constantFoldId for the specified range of identifiers, + /// 'num' identifiers starting at position 'pos'. + void constantFoldIdRange(unsigned pos, unsigned num); + + /// Updates the constraints to be the smallest bounding (enclosing) box that + /// contains the points of 'this' set and that of 'other', with the symbols + /// being treated specially. For each of the dimensions, the min of the lower + /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed + /// to determine such a bounding box. 
`other' is expected to have the same + /// dimensional identifiers as this constraint system (in the same order). + /// + /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the + /// output is {0 <= d0 <= 192}. + /// 2) 'this' = {s0 + 5 <= d0 <= s0 + 20}, 'other' is {s0 + 1 <= d0 <= s0 + + /// 9}, output = {s0 + 1 <= d0 <= s0 + 20}. + /// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1 + /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. + LogicalResult unionBoundingBox(const FlatAffineConstraints &other); + + /// Returns 'true' if this constraint system and 'other' are in the same + /// space, i.e., if they are associated with the same set of identifiers, + /// appearing in the same order. Returns 'false' otherwise. + bool areIdsAlignedWithOther(const FlatAffineConstraints &other); + + /// Merge and align the identifiers of 'this' and 'other' starting at + /// 'offset', so that both constraint systems get the union of the contained + /// identifiers that is dimension-wise and symbol-wise unique; both + /// constraint systems are updated so that they have the union of all + /// identifiers, with this's original identifiers appearing first followed by + /// any of other's identifiers that didn't appear in 'this'. Local + /// identifiers of each system are by design separate/local and are placed + /// one after other (this's followed by other's). 
+ // Eg: Input: 'this' has ((%i %j) [%M %N]) + // 'other' has (%k, %j) [%P, %N, %M]) + // Output: both 'this', 'other' have (%i, %j, %k) [%M, %N, %P] + // + void mergeAndAlignIdsWithOther(unsigned offset, FlatAffineConstraints *other); + + unsigned getNumConstraints() const { + return getNumInequalities() + getNumEqualities(); + } + inline unsigned getNumIds() const { return numIds; } + inline unsigned getNumDimIds() const { return numDims; } + inline unsigned getNumSymbolIds() const { return numSymbols; } + inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; } + inline unsigned getNumLocalIds() const { + return numIds - numDims - numSymbols; + } + + inline ArrayRef> getIds() const { + return {ids.data(), ids.size()}; + } + inline MutableArrayRef> getIds() { + return {ids.data(), ids.size()}; + } + + /// Returns the optional Value corresponding to the pos^th identifier. + inline Optional getId(unsigned pos) const { return ids[pos]; } + inline Optional &getId(unsigned pos) { return ids[pos]; } + + /// Returns the Value associated with the pos^th identifier. Asserts if + /// no Value identifier was associated. + inline Value getIdValue(unsigned pos) const { + assert(ids[pos].hasValue() && "identifier's Value not set"); + return ids[pos].getValue(); + } + + /// Returns the Values associated with identifiers in range [start, end). + /// Asserts if no Value was associated with one of these identifiers. + void getIdValues(unsigned start, unsigned end, + SmallVectorImpl *values) const { + assert((start < numIds || start == end) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + values->clear(); + values->reserve(end - start); + for (unsigned i = start; i < end; i++) { + values->push_back(getIdValue(i)); + } + } + inline void getAllIdValues(SmallVectorImpl *values) const { + getIdValues(0, numIds, values); + } + + /// Sets Value associated with the pos^th identifier. 
+ inline void setIdValue(unsigned pos, Value val) { + assert(pos < numIds && "invalid id position"); + ids[pos] = val; + } + /// Sets Values associated with identifiers in the range [start, end). + void setIdValues(unsigned start, unsigned end, ArrayRef values) { + assert((start < numIds || end == start) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + assert(values.size() == end - start); + for (unsigned i = start; i < end; ++i) + ids[i] = values[i - start]; + } + + /// Clears this list of constraints and copies other into it. + void clearAndCopyFrom(const FlatAffineConstraints &other); + + /// Returns the smallest known constant bound for the extent of the specified + /// identifier (pos^th), i.e., the smallest known constant that is greater + /// than or equal to 'exclusive upper bound' - 'lower bound' of the + /// identifier. Returns None if it's not a constant. This method employs + /// trivial (low complexity / cost) checks and detection. Symbolic identifiers + /// are treated specially, i.e., it looks for constant differences between + /// affine expressions involving only the symbolic identifiers. See comments + /// at function definition for examples. 'lb' and 'lbDivisor', if provided, + /// are used to express the lower bound associated with the constant + /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg., + /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three + /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32. + Optional + getConstantBoundOnDimSize(unsigned pos, + SmallVectorImpl *lb = nullptr, + int64_t *lbFloorDivisor = nullptr, + SmallVectorImpl *ub = nullptr) const; + + /// Returns the constant lower bound for the pos^th identifier if there is + /// one; None otherwise. + Optional getConstantLowerBound(unsigned pos) const; + + /// Returns the constant upper bound for the pos^th identifier if there is + /// one; None otherwise. 
+ Optional getConstantUpperBound(unsigned pos) const; + + /// Gets the lower and upper bound of the pos^th identifier treating + /// [0, offset) U [offset + num, symStartPos) as dimensions and + /// [symStartPos, getNumDimAndSymbolIds) as symbols. The returned + /// multi-dimensional maps in the pair represent the max and min of + /// potentially multiple affine expressions. The upper bound is exclusive. + /// 'localExprs' holds pre-computed AffineExpr's for all local identifiers in + /// the system. + std::pair + getLowerAndUpperBound(unsigned pos, unsigned offset, unsigned num, + unsigned symStartPos, ArrayRef localExprs, + MLIRContext *context) const; + + /// Returns true if the set can be trivially detected as being + /// hyper-rectangular on the specified contiguous set of identifiers. + bool isHyperRectangular(unsigned pos, unsigned num) const; + + /// Removes duplicate constraints, trivially true constraints, and constraints + /// that can be detected as redundant as a result of differing only in their + /// constant term part. A constraint of the form >= 0 + /// is considered trivially true. This method is a linear time method on the + /// constraints, does a single scan, and updates in place. + void removeTrivialRedundancy(); + + /// A more expensive check to detect redundant inequalities thatn + /// removeTrivialRedundancy. + void removeRedundantInequalities(); + + // Removes all equalities and inequalities. + void clearConstraints(); + + void print(raw_ostream &os) const; + void dump() const; + +private: + /// Returns false if the fields corresponding to various identifier counts, or + /// equality/inequality buffer sizes aren't consistent; true otherwise. This + /// is meant to be used within an assert internally. + bool hasConsistentState() const; + + /// Checks all rows of equality/inequality constraints for trivial + /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced + /// after elimination. 
Returns 'true' if an invalid constraint is found; + /// 'false'otherwise. + bool hasInvalidConstraint() const; + + /// Returns the constant lower bound bound if isLower is true, and the upper + /// bound if isLower is false. + template + Optional computeConstantLowerOrUpperBound(unsigned pos); + + // Eliminates a single identifier at 'position' from equality and inequality + // constraints. Returns 'success' if the identifier was eliminated, and + // 'failure' otherwise. + inline LogicalResult gaussianEliminateId(unsigned position) { + return success(gaussianEliminateIds(position, position + 1) == 1); + } + + // Eliminates identifiers from equality and inequality constraints + // in column range [posStart, posLimit). + // Returns the number of variables eliminated. + unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit); + + /// Eliminates identifier at the specified position using Fourier-Motzkin + /// variable elimination, but uses Gaussian elimination if there is an + /// equality involving that identifier. If the result of the elimination is + /// integer exact, *isResultIntegerExact is set to true. If 'darkShadow' is + /// set to true, a potential under approximation (subset) of the rational + /// shadow / exact integer shadow is computed. + // See implementation comments for more details. + void FourierMotzkinEliminate(unsigned pos, bool darkShadow = false, + bool *isResultIntegerExact = nullptr); + + /// Tightens inequalities given that we are dealing with integer spaces. This + /// is similar to the GCD test but applied to inequalities. The constant term + /// can be reduced to the preceding multiple of the GCD of the coefficients, + /// i.e., + /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a + /// fast method (linear in the number of coefficients). + void GCDTightenInequalities(); + + /// Normalized each constraints by the GCD of its coefficients. 
+ void normalizeConstraintsByGCD(); + + /// Removes identifiers in column range [idStart, idLimit), and copies any + /// remaining valid data into place, updates member variables, and resizes + /// arrays as needed. + void removeIdRange(unsigned idStart, unsigned idLimit); + + /// Coefficients of affine equalities (in == 0 form). + SmallVector equalities; + + /// Coefficients of affine inequalities (in >= 0 form). + SmallVector inequalities; + + /// Number of columns reserved. Actual ones in used are returned by + /// getNumCols(). + unsigned numReservedCols; + + /// Total number of identifiers. + unsigned numIds; + + /// Number of identifiers corresponding to real dimensions. + unsigned numDims; + + /// Number of identifiers corresponding to symbols (unknown but constant for + /// analysis). + unsigned numSymbols; + + /// Values corresponding to the (column) identifiers of this constraint + /// system appearing in the order the identifiers correspond to columns. + /// Temporary ones or those that aren't associated to any Value are set to + /// None. + SmallVector, 8> ids; + + /// A parameter that controls detection of an unrealistic number of + /// constraints. If the number of constraints is this many times the number of + /// variables, we consider such a system out of line with the intended use + /// case of FlatAffineConstraints. + // The rationale for 32 is that in the typical simplest of cases, an + // identifier is expected to have one lower bound and one upper bound + // constraint. With a level of tiling or a connection to another identifier + // through a div or mod, an extra pair of bounds gets added. As a limit, we + // don't expect an identifier to have more than 32 lower/upper/equality + // constraints. This is conservatively set low and can be raised if needed. + constexpr static unsigned kExplosionFactor = 32; +}; + +/// Simplify an affine expression by flattening and some amount of +/// simple analysis. 
This has complexity linear in the number of nodes in +/// 'expr'. Returns the simplified expression, which is the same as the input +/// expression if it can't be simplified. +AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols); + +/// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the +/// dimensions, symbols, and additional variables that represent floor divisions +/// of dimensions, symbols, and in turn other floor divisions. Returns failure +/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled). +/// 'cst' contains constraints that connect newly introduced local identifiers +/// to existing dimensional and symbolic identifiers. See documentation for +/// AffineExprFlattener on how mod's and div's are flattened. +LogicalResult getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + SmallVectorImpl *flattenedExpr, + FlatAffineConstraints *cst = nullptr); + +/// Flattens the result expressions of the map to their corresponding flattened +/// forms and set in 'flattenedExprs'. Returns failure if any expression in the +/// map could not be flattened (i.e., semi-affine is not yet handled). 'cst' +/// contains constraints that connect newly introduced local identifiers to +/// existing dimensional and / symbolic identifiers. See documentation for +/// AffineExprFlattener on how mod's and div's are flattened. For all affine +/// expressions that share the same operands (like those of an affine map), this +/// method should be used instead of repeatedly calling getFlattenedAffineExpr +/// since local variables added to deal with div's and mod's will be reused +/// across expressions. 
+LogicalResult +getFlattenedAffineExprs(AffineMap map, + std::vector> *flattenedExprs, + FlatAffineConstraints *cst = nullptr); +LogicalResult +getFlattenedAffineExprs(IntegerSet set, + std::vector> *flattenedExprs, + FlatAffineConstraints *cst = nullptr); + +} // end namespace mlir. + +#endif // MLIR_ANALYSIS_AFFINE_STRUCTURES_H diff --git a/mlir/include/mlir/Analysis/CMakeLists.txt b/mlir/include/mlir/Analysis/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3d9a7ed369799f04de873b23e532c4bf7fbdb74a --- /dev/null +++ b/mlir/include/mlir/Analysis/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS CallInterfaces.td) +mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRCallOpInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td) +mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRTypeInferOpInterfaceIncGen) diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h new file mode 100644 index 0000000000000000000000000000000000000000..8f954161921edb645abc78acc9d57632ad0f01d3 --- /dev/null +++ b/mlir/include/mlir/Analysis/CallGraph.h @@ -0,0 +1,253 @@ +//===- CallGraph.h - CallGraph analysis for MLIR ----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains an analysis for computing the multi-level callgraph from a +// given top-level operation. 
This nodes within this callgraph are defined by +// the `CallOpInterface` and `CallableOpInterface` operation interfaces defined +// in CallInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_CALLGRAPH_H +#define MLIR_ANALYSIS_CALLGRAPH_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir { +struct CallInterfaceCallable; +class Operation; +class Region; + +//===----------------------------------------------------------------------===// +// CallGraphNode +//===----------------------------------------------------------------------===// + +/// This class represents a single callable in the callgraph. Aside from the +/// external node, each node represents a callable node in the graph and +/// contains a valid corresponding Region. The external node is a virtual node +/// used to represent external edges into, and out of, the callgraph. +class CallGraphNode { +public: + /// This class represents a directed edge between two nodes in the callgraph. + class Edge { + enum class Kind { + // An 'Abstract' edge represents an opaque, non-operation, reference + // between this node and the target. Edges of this type are only valid + // from the external node, as there is no valid connection to an operation + // in the module. + Abstract, + + // A 'Call' edge represents a direct reference to the target node via a + // call-like operation within the callable region of this node. + Call, + + // A 'Child' edge is used when the region of target node is defined inside + // of the callable region of this node. This means that the region of this + // node is an ancestor of the region for the target node. As such, this + // edge cannot be used on the 'external' node. + Child, + }; + + public: + /// Returns if this edge represents an `Abstract` edge. 
+ bool isAbstract() const { return targetAndKind.getInt() == Kind::Abstract; } + + /// Returns if this edge represents a `Call` edge. + bool isCall() const { return targetAndKind.getInt() == Kind::Call; } + + /// Returns if this edge represents a `Child` edge. + bool isChild() const { return targetAndKind.getInt() == Kind::Child; } + + /// Returns the target node for this edge. + CallGraphNode *getTarget() const { return targetAndKind.getPointer(); } + + bool operator==(const Edge &edge) const { + return targetAndKind == edge.targetAndKind; + } + + private: + Edge(CallGraphNode *node, Kind kind) : targetAndKind(node, kind) {} + explicit Edge(llvm::PointerIntPair targetAndKind) + : targetAndKind(targetAndKind) {} + + /// The target node of this edge, as well as the edge kind. + llvm::PointerIntPair targetAndKind; + + // Provide access to the constructor and Kind. + friend class CallGraphNode; + }; + + /// Returns if this node is the external node. + bool isExternal() const; + + /// Returns the callable region this node represents. This can only be called + /// on non-external nodes. + Region *getCallableRegion() const; + + /// Adds an abstract reference edge to the given node. An abstract edge does + /// not come from any observable operations, so this is only valid on the + /// external node. + void addAbstractEdge(CallGraphNode *node); + + /// Add an outgoing call edge from this node. + void addCallEdge(CallGraphNode *node); + + /// Adds a reference edge to the given child node. + void addChildEdge(CallGraphNode *child); + + /// Iterator over the outgoing edges of this node. + using iterator = SmallVectorImpl::const_iterator; + iterator begin() const { return edges.begin(); } + iterator end() const { return edges.end(); } + + /// Returns true if this node has any child edges. + bool hasChildren() const; + +private: + /// DenseMap info for callgraph edges. 
+ struct EdgeKeyInfo { + using BaseInfo = + DenseMapInfo>; + + static Edge getEmptyKey() { return Edge(BaseInfo::getEmptyKey()); } + static Edge getTombstoneKey() { return Edge(BaseInfo::getTombstoneKey()); } + static unsigned getHashValue(const Edge &edge) { + return BaseInfo::getHashValue(edge.targetAndKind); + } + static bool isEqual(const Edge &lhs, const Edge &rhs) { return lhs == rhs; } + }; + + CallGraphNode(Region *callableRegion) : callableRegion(callableRegion) {} + + /// Add an edge to 'node' with the given kind. + void addEdge(CallGraphNode *node, Edge::Kind kind); + + /// The callable region defines the boundary of the call graph node. This is + /// the region referenced by 'call' operations. This is at a per-region + /// boundary as operations may define multiple callable regions. + Region *callableRegion; + + /// A set of out-going edges from this node to other nodes in the graph. + llvm::SetVector, + llvm::SmallDenseSet> + edges; + + // Provide access to private methods. + friend class CallGraph; +}; + +//===----------------------------------------------------------------------===// +// CallGraph +//===----------------------------------------------------------------------===// + +class CallGraph { + using NodeMapT = llvm::MapVector>; + + /// This class represents an iterator over the internal call graph nodes. This + /// class unwraps the map iterator to access the raw node. + class NodeIterator final + : public llvm::mapped_iterator< + NodeMapT::const_iterator, + CallGraphNode *(*)(const NodeMapT::value_type &)> { + static CallGraphNode *unwrap(const NodeMapT::value_type &value) { + return value.second.get(); + } + + public: + /// Initializes the result type iterator to the specified result iterator. 
+ NodeIterator(NodeMapT::const_iterator it) + : llvm::mapped_iterator< + NodeMapT::const_iterator, + CallGraphNode *(*)(const NodeMapT::value_type &)>(it, &unwrap) {} + }; + +public: + CallGraph(Operation *op); + + /// Get or add a call graph node for the given region. `parentNode` + /// corresponds to the direct node in the callgraph that contains the parent + /// operation of `region`, or nullptr if there is no parent node. + CallGraphNode *getOrAddNode(Region *region, CallGraphNode *parentNode); + + /// Lookup a call graph node for the given region, or nullptr if none is + /// registered. + CallGraphNode *lookupNode(Region *region) const; + + /// Return the callgraph node representing the indirect-external callee. + CallGraphNode *getExternalNode() const { + return const_cast(&externalNode); + } + + /// Resolve the callable for given callee to a node in the callgraph, or the + /// external node if a valid node was not resolved. 'from' provides an anchor + /// for symbol table lookups, and is only required if the callable is a symbol + /// reference. + CallGraphNode *resolveCallable(CallInterfaceCallable callable, + Operation *from = nullptr) const; + + /// An iterator over the nodes of the graph. + using iterator = NodeIterator; + iterator begin() const { return nodes.begin(); } + iterator end() const { return nodes.end(); } + + /// Dump the graph in a human readable format. + void dump() const; + void print(raw_ostream &os) const; + +private: + /// The set of nodes within the callgraph. + NodeMapT nodes; + + /// A special node used to indicate an external edges. + CallGraphNode externalNode; +}; + +} // end namespace mlir + +namespace llvm { +// Provide graph traits for traversing call graphs using standard graph +// traversals. 
+template <> struct GraphTraits { + using NodeRef = mlir::CallGraphNode *; + static NodeRef getEntryNode(NodeRef node) { return node; } + + static NodeRef unwrap(const mlir::CallGraphNode::Edge &edge) { + return edge.getTarget(); + } + + // ChildIteratorType/begin/end - Allow iteration over all nodes in the graph. + using ChildIteratorType = + mapped_iterator; + static ChildIteratorType child_begin(NodeRef node) { + return {node->begin(), &unwrap}; + } + static ChildIteratorType child_end(NodeRef node) { + return {node->end(), &unwrap}; + } +}; + +template <> +struct GraphTraits + : public GraphTraits { + /// The entry node into the graph is the external node. + static NodeRef getEntryNode(const mlir::CallGraph *cg) { + return cg->getExternalNode(); + } + + // nodes_iterator/begin/end - Allow iteration over all nodes in the graph + using nodes_iterator = mlir::CallGraph::iterator; + static nodes_iterator nodes_begin(mlir::CallGraph *cg) { return cg->begin(); } + static nodes_iterator nodes_end(mlir::CallGraph *cg) { return cg->end(); } +}; +} // end namespace llvm + +#endif // MLIR_ANALYSIS_CALLGRAPH_H diff --git a/mlir/include/mlir/Analysis/CallInterfaces.h b/mlir/include/mlir/Analysis/CallInterfaces.h new file mode 100644 index 0000000000000000000000000000000000000000..b5870bac1429178c4680848942fc6db622ce1663 --- /dev/null +++ b/mlir/include/mlir/Analysis/CallInterfaces.h @@ -0,0 +1,31 @@ +//===- CallInterfaces.h - Call Interfaces for MLIR --------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the call interfaces defined in +// `CallInterfaces.td`. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_CALLINTERFACES_H +#define MLIR_ANALYSIS_CALLINTERFACES_H + +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/PointerUnion.h" + +namespace mlir { + +/// A callable is either a symbol, or an SSA value, that is referenced by a +/// call-like operation. This represents the destination of the call. +struct CallInterfaceCallable : public PointerUnion { + using PointerUnion::PointerUnion; +}; + +#include "mlir/Analysis/CallInterfaces.h.inc" +} // end namespace mlir + +#endif // MLIR_ANALYSIS_CALLINTERFACES_H diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td new file mode 100644 index 0000000000000000000000000000000000000000..3e5b599baf8faf283fcd5fa67a42fc9586d650aa --- /dev/null +++ b/mlir/include/mlir/Analysis/CallInterfaces.td @@ -0,0 +1,84 @@ +//===- CallInterfaces.td - Call Interfaces for ops -*- tablegen ---------*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to define information +// related to call-like and callable operations. Each of which are defined along +// with the respective interface below. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CALLINTERFACES +#define MLIR_CALLINTERFACES + +include "mlir/IR/OpBase.td" + +// `CallInterfaceCallable`: This is a type used to represent a single callable +// region. A callable is either a symbol, or an SSA value, that is referenced by +// a call-like operation. This represents the destination of the call. + +/// Interface for call-like operations. 
+def CallOpInterface : OpInterface<"CallOpInterface"> { + let description = [{ + A call-like operation is one that transfers control from one sub-routine to + another. These operations may be traditional direct calls `call @foo`, or + indirect calls to other operations `call_indirect %foo`. An operation that + uses this interface, must *not* also provide the `CallableOpInterface`. + }]; + + let methods = [ + InterfaceMethod<[{ + Returns the callee of this call-like operation. A `callee` is either a + reference to a symbol, via SymbolRefAttr, or a reference to a defined + SSA value. + }], + "CallInterfaceCallable", "getCallableForCallee" + >, + InterfaceMethod<[{ + Returns the operands within this call that are used as arguments to the + callee. + }], + "Operation::operand_range", "getArgOperands" + >, + ]; +} + +/// Interface for callable operations. +def CallableOpInterface : OpInterface<"CallableOpInterface"> { + let description = [{ + A callable operation is one who represents a potential sub-routine, and may + be a target for a call-like operation (those providing the CallOpInterface + above). These operations may be traditional functional operation + `func @foo(...)`, as well as function producing operations + `%foo = dialect.create_function(...)`. These operations may produce multiple + callable regions, or subroutines. + }]; + + let methods = [ + InterfaceMethod<[{ + Returns a region on the current operation that the given callable refers + to. This may return null in the case of an external callable object, + e.g. an external function. + }], + "Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable) + >, + InterfaceMethod<[{ + Returns all of the callable regions of this operation. + }], + "void", "getCallableRegions", + (ins "SmallVectorImpl &":$callables) + >, + InterfaceMethod<[{ + Returns the results types that the given callable region produces when + executed. 
+ }], + "ArrayRef", "getCallableResults", (ins "Region *":$callable) + >, + ]; +} + +#endif // MLIR_CALLINTERFACES diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h new file mode 100644 index 0000000000000000000000000000000000000000..ead54b93e8084592cb196b1980f1961dbf0b989e --- /dev/null +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -0,0 +1,141 @@ +//===- Dominance.h - Dominator analysis for CFGs ----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DOMINANCE_H +#define MLIR_ANALYSIS_DOMINANCE_H + +#include "mlir/IR/RegionGraphTraits.h" +#include "llvm/Support/GenericDomTree.h" + +extern template class llvm::DominatorTreeBase; +extern template class llvm::DominatorTreeBase; + +namespace mlir { +using DominanceInfoNode = llvm::DomTreeNodeBase; +class Operation; + +namespace detail { +template class DominanceInfoBase { + using base = llvm::DominatorTreeBase; + +public: + DominanceInfoBase(Operation *op) { recalculate(op); } + DominanceInfoBase(DominanceInfoBase &&) = default; + DominanceInfoBase &operator=(DominanceInfoBase &&) = default; + + DominanceInfoBase(const DominanceInfoBase &) = delete; + DominanceInfoBase &operator=(const DominanceInfoBase &) = delete; + + /// Recalculate the dominance info. + void recalculate(Operation *op); + + /// Get the root dominance node of the given region. + DominanceInfoNode *getRootNode(Region *region) { + assert(dominanceInfos.count(region) != 0); + return dominanceInfos[region]->getRootNode(); + } + +protected: + using super = DominanceInfoBase; + + /// Return true if the specified block A properly dominates block B. 
+ bool properlyDominates(Block *a, Block *b); + + /// A mapping of regions to their base dominator tree. + DenseMap> dominanceInfos; +}; +} // end namespace detail + +/// A class for computing basic dominance information. +class DominanceInfo : public detail::DominanceInfoBase { +public: + using super::super; + + /// Return true if operation A properly dominates operation B. + bool properlyDominates(Operation *a, Operation *b); + + /// Return true if operation A dominates operation B. + bool dominates(Operation *a, Operation *b) { + return a == b || properlyDominates(a, b); + } + + /// Return true if value A properly dominates operation B. + bool properlyDominates(Value a, Operation *b); + + /// Return true if operation A dominates operation B. + bool dominates(Value a, Operation *b) { + return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b); + } + + /// Return true if the specified block A dominates block B. + bool dominates(Block *a, Block *b) { + return a == b || properlyDominates(a, b); + } + + /// Return true if the specified block A properly dominates block B. + bool properlyDominates(Block *a, Block *b) { + return super::properlyDominates(a, b); + } + + /// Return the dominance node from the Region containing block A. + DominanceInfoNode *getNode(Block *a); + + /// Update the internal DFS numbers for the dominance nodes. + void updateDFSNumbers(); +}; + +/// A class for computing basic postdominance information. +class PostDominanceInfo : public detail::DominanceInfoBase { +public: + using super::super; + + /// Return true if operation A properly postdominates operation B. + bool properlyPostDominates(Operation *a, Operation *b); + + /// Return true if operation A postdominates operation B. + bool postDominates(Operation *a, Operation *b) { + return a == b || properlyPostDominates(a, b); + } + + /// Return true if the specified block A properly postdominates block B. 
+ bool properlyPostDominates(Block *a, Block *b) { + return super::properlyDominates(a, b); + } + + /// Return true if the specified block A postdominates block B. + bool postDominates(Block *a, Block *b) { + return a == b || properlyPostDominates(a, b); + } +}; + +} // end namespace mlir + +namespace llvm { + +/// DominatorTree GraphTraits specialization so the DominatorTree can be +/// iterated by generic graph iterators. +template <> struct GraphTraits { + using ChildIteratorType = mlir::DominanceInfoNode::iterator; + using NodeRef = mlir::DominanceInfoNode *; + + static NodeRef getEntryNode(NodeRef N) { return N; } + static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); } + static inline ChildIteratorType child_end(NodeRef N) { return N->end(); } +}; + +template <> struct GraphTraits { + using ChildIteratorType = mlir::DominanceInfoNode::const_iterator; + using NodeRef = const mlir::DominanceInfoNode *; + + static NodeRef getEntryNode(NodeRef N) { return N; } + static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); } + static inline ChildIteratorType child_end(NodeRef N) { return N->end(); } +}; + +} // end namespace llvm +#endif diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..baf16162a0be5563f2fddbbcabfda74f0812e055 --- /dev/null +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h @@ -0,0 +1,44 @@ +//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the infer op interfaces defined in +// `InferTypeOpInterface.td`. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ +#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +#include "mlir/Analysis/InferTypeOpInterface.h.inc" + +namespace OpTrait { +template +class TypeOpInterfaceDefault + : public TraitBase { +public: + /// Returns whether two arrays are equal as strongest check for compatibility + /// by default. + static bool isCompatibleReturnTypes(ArrayRef lhs, ArrayRef rhs) { + return lhs == rhs; + }; +}; +} // namespace OpTrait + +} // namespace mlir + +#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td new file mode 100644 index 0000000000000000000000000000000000000000..bbcea6be7eb3fd86256e098a8b77e308c0787dcb --- /dev/null +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -0,0 +1,65 @@ +//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to define information +// related to type inference. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INFERTYPEOPINTERFACE +#define MLIR_INFERTYPEOPINTERFACE + +include "mlir/IR/OpBase.td" + +// OpInterface to compute the return type of an operation. 
The arguments match
+// those in Operation::create with the exception that the location is optional
+// (if no location is provided, then the method will not emit an error on
+// mismatch).
+def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
+  let description = [{
+    Interface to access a registered method to infer the return types for an
+    operation that could be used during op construction, verification or
+    type inference.
+  }];
+
+  let methods = [
+    StaticInterfaceMethod<
+      /*desc=*/[{Infer the return types that an op would generate.
+
+      The method takes an optional location which, if set, will be used to
+      report errors on. The operands and attributes correspond to those with
+      which an Operation would be created (e.g., as used in Operation::create)
+      and the regions of the op.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"inferReturnTypes",
+      /*args=*/(ins "Optional<Location>":$location,
+                    "ValueRange":$operands,
+                    "ArrayRef<NamedAttribute>":$attributes,
+                    "RegionRange":$regions,
+                    "SmallVectorImpl<Type>&":$inferedReturnTypes)
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/"Returns whether two array of types are compatible result types"
+               " for an op.",
+      /*retTy=*/"bool",
+      /*methodName=*/"isCompatibleReturnTypes",
+      /*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
+      /*methodBody=*/[{
+        return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
+      }],
+      /*defaultImplementation=*/[{
+        /// Returns whether two arrays are equal as strongest check for
+        /// compatibility by default.
+ return lhs == rhs; + }] + >, + ]; +} + +#endif // MLIR_INFERTYPEOPINTERFACE diff --git a/mlir/include/mlir/Analysis/Liveness.h b/mlir/include/mlir/Analysis/Liveness.h new file mode 100644 index 0000000000000000000000000000000000000000..7e1dc2903ae345bd037fd35a695ccfc1a752f12e --- /dev/null +++ b/mlir/include/mlir/Analysis/Liveness.h @@ -0,0 +1,148 @@ +//===- Liveness.h - Liveness analysis for MLIR ------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains an analysis for computing liveness information from a +// given top-level operation. The current version of the analysis uses a +// traditional algorithm to resolve detailed live-range information about all +// values within the specified regions. It is also possible to query liveness +// information on block level. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_LIVENESS_H +#define MLIR_ANALYSIS_LIVENESS_H + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" + +namespace mlir { + +class Block; +class LivenessBlockInfo; +class Operation; +class Region; +class Value; + +/// Represents an analysis for computing liveness information from a +/// given top-level operation. The analysis iterates over all associated +/// regions that are attached to the given top-level operation. It +/// computes liveness information for every value and block that are +/// included in the mentioned regions. It relies on a fixpoint iteration +/// to compute all live-in and live-out values of all included blocks. 
+/// Sample usage:
+///   Liveness liveness(topLevelOp);
+///   auto &allInValues = liveness.getLiveIn(block);
+///   auto &allOutValues = liveness.getLiveOut(block);
+///   auto allOperationsInWhichValueIsLive = liveness.resolveLiveness(value);
+///   bool lastUse = liveness.isLastUse(value, operation);
+class Liveness {
+public:
+  using OperationListT = std::vector<Operation *>;
+  using BlockMapT = DenseMap<Block *, LivenessBlockInfo>;
+  using ValueSetT = SmallPtrSet<Value, 16>;
+
+public:
+  /// Creates a new Liveness analysis that computes liveness
+  /// information for all associated regions.
+  Liveness(Operation *op);
+
+  /// Returns the operation this analysis was constructed from.
+  Operation *getOperation() const { return operation; }
+
+  /// Gets liveness info (if any) for the given value.
+  /// This includes all operations in which the given value is live.
+  /// Note that the operations in this list are not ordered and the current
+  /// implementation is computationally expensive (as it iterates over all
+  /// blocks in which the given value is live).
+  OperationListT resolveLiveness(Value value) const;
+
+  /// Gets liveness info (if any) for the block.
+  const LivenessBlockInfo *getLiveness(Block *block) const;
+
+  /// Returns a reference to a set containing live-in values (unordered).
+  const ValueSetT &getLiveIn(Block *block) const;
+
+  /// Returns a reference to a set containing live-out values (unordered).
+  const ValueSetT &getLiveOut(Block *block) const;
+
+  /// Returns true if the given operation represent the last use of the
+  /// given value.
+  bool isLastUse(Value value, Operation *operation) const;
+
+  /// Dumps the liveness information in a human readable format.
+  void dump() const;
+
+  /// Dumps the liveness information to the given stream.
+  void print(raw_ostream &os) const;
+
+private:
+  /// Initializes the internal mappings.
+  void build(MutableArrayRef<Region> regions);
+
+private:
+  /// The operation this analysis was constructed from.
+ Operation *operation; + + /// Maps blocks to internal liveness information. + BlockMapT blockMapping; +}; + +/// This class represents liveness information on block level. +class LivenessBlockInfo { +public: + /// A typedef declaration of a value set. + using ValueSetT = Liveness::ValueSetT; + +public: + /// Returns the underlying block. + Block *getBlock() const { return block; } + + /// Returns all values that are live at the beginning + /// of the block (unordered). + const ValueSetT &in() const { return inValues; } + + /// Returns all values that are live at the end + /// of the block (unordered). + const ValueSetT &out() const { return outValues; } + + /// Returns true if the given value is in the live-in set. + bool isLiveIn(Value value) const; + + /// Returns true if the given value is in the live-out set. + bool isLiveOut(Value value) const; + + /// Gets the start operation for the given value. This is the first operation + /// the given value is considered to be live. This could either be the start + /// operation of the current block (in case the value is live-in) or the + /// operation that defines the given value (must be referenced in this block). + Operation *getStartOperation(Value value) const; + + /// Gets the end operation for the given value using the start operation + /// provided (must be referenced in this block). + Operation *getEndOperation(Value value, Operation *startOperation) const; + +private: + /// The underlying block. + Block *block; + + /// The set of all live in values. + ValueSetT inValues; + + /// The set of all live out values. 
+ ValueSetT outValues; + + friend class Liveness; +}; + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_LIVENESS_H diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h new file mode 100644 index 0000000000000000000000000000000000000000..0dd89e454a83a1673f0bc1adcececce4f16b1950 --- /dev/null +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -0,0 +1,88 @@ +//===- LoopAnalysis.h - loop analysis methods -------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes for methods to analyze loops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_LOOP_ANALYSIS_H +#define MLIR_ANALYSIS_LOOP_ANALYSIS_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" + +namespace mlir { + +class AffineExpr; +class AffineForOp; +class AffineMap; +class MemRefType; +class NestedPattern; +class Operation; +class Value; + +/// Returns the trip count of the loop as an affine map with its corresponding +/// operands if the latter is expressible as an affine expression, and nullptr +/// otherwise. This method always succeeds as long as the lower bound is not a +/// multi-result map. The trip count expression is simplified before returning. 
+/// This method only utilizes map composition to construct lower and upper
+/// bounds before computing the trip count expressions
+// TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a
+// pure analysis method relying on FlatAffineConstraints
+void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
+                                  SmallVectorImpl<Value> *operands);
+
+/// Returns the trip count of the loop if it's a constant, None otherwise. This
+/// uses affine expression analysis and is able to determine constant trip count
+/// in non-trivial cases.
+Optional<uint64_t> getConstantTripCount(AffineForOp forOp);
+
+/// Returns the greatest known integral divisor of the trip count. Affine
+/// expression analysis is used (indirectly through getTripCount), and
+/// this method is thus able to determine non-trivial divisors.
+uint64_t getLargestDivisorOfTripCount(AffineForOp forOp);
+
+/// Given an induction variable `iv` of type AffineForOp and `indices` of type
+/// IndexType, returns the set of `indices` that are independent of `iv`.
+///
+/// Prerequisites (inherited from `isAccessInvariant` above):
+///   1. `iv` and `indices` of the proper type;
+///   2. at most one affine.apply is reachable from each index in `indices`;
+///
+/// Emits a note if it encounters a chain of affine.apply and conservatively
+/// those cases.
+DenseSet<Value, DenseMapInfo<Value>>
+getInvariantAccesses(Value iv, ArrayRef<Value> indices);
+
+using VectorizableLoopFun = std::function<bool(AffineForOp)>;
+
+/// Checks whether the loop is structurally vectorizable; i.e.:
+///   1. no conditionals are nested under the loop;
+///   2. all nested load/stores are to scalar MemRefs.
+/// TODO(ntv): relax the no-conditionals restriction
+bool isVectorizableLoopBody(AffineForOp loop,
+                            NestedPattern &vectorTransferMatcher);
+
+/// Checks whether the loop is structurally vectorizable and that all the LoadOp
+/// and StoreOp matched have access indexing functions that are are either:
+///   1. invariant along the loop induction variable created by 'loop';
+///   2.
varying along at most one memory dimension. If such a unique dimension +/// is found, it is written into `memRefDim`. +bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim, + NestedPattern &vectorTransferMatcher); + +/// Checks where SSA dominance would be violated if a for op's body +/// operations are shifted by the specified shifts. This method checks if a +/// 'def' and all its uses have the same shift factor. +// TODO(mlir-team): extend this to check for memory-based dependence +// violation when we have the support. +bool isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts); +} // end namespace mlir + +#endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h new file mode 100644 index 0000000000000000000000000000000000000000..2da64e88e14876f798acf9adc94d7d95df64ca05 --- /dev/null +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -0,0 +1,187 @@ +//===- NestedMacher.h - Nested matcher for Function -------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ +#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ + +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "llvm/Support/Allocator.h" + +namespace mlir { + +class NestedPattern; +class Operation; + +/// An NestedPattern captures nested patterns in the IR. +/// It is used in conjunction with a scoped NestedPatternContext which is an +/// llvm::BumpPtrAllocator that handles memory allocations efficiently and +/// avoids ownership issues. +/// +/// In order to use NestedPatterns, first create a scoped context. +/// When the context goes out of scope, everything is freed. 
+/// This design simplifies the API by avoiding references to the context and +/// makes it clear that references to matchers must not escape. +/// +/// Example: +/// { +/// NestedPatternContext context; +/// auto gemmLike = Doall(Doall(Red(LoadStores()))); +/// auto matches = gemmLike.match(f); +/// // do work on matches +/// } // everything is freed +/// +/// +/// Nested abstraction for matching results. +/// Provides access to the nested Operation* captured by a Matcher. +/// +/// A NestedMatch contains an Operation* and the children NestedMatch and is +/// thus cheap to copy. NestedMatch is stored in a scoped bumper allocator whose +/// lifetime is managed by an RAII NestedPatternContext. +class NestedMatch { +public: + static NestedMatch build(Operation *operation, + ArrayRef nestedMatches); + NestedMatch(const NestedMatch &) = default; + NestedMatch &operator=(const NestedMatch &) = default; + + explicit operator bool() { return matchedOperation != nullptr; } + + Operation *getMatchedOperation() { return matchedOperation; } + ArrayRef getMatchedChildren() { return matchedChildren; } + +private: + friend class NestedPattern; + friend class NestedPatternContext; + + /// Underlying global bump allocator managed by a NestedPatternContext. + static llvm::BumpPtrAllocator *&allocator(); + + NestedMatch() = default; + + /// Payload, holds a NestedMatch and all its children along this branch. + Operation *matchedOperation; + ArrayRef matchedChildren; +}; + +/// A NestedPattern is a nested operation walker that: +/// 1. recursively matches a substructure in the tree; +/// 2. uses a filter function to refine matches with extra semantic +/// constraints (passed via a lambda of type FilterFunctionType); +/// 3. TODO(ntv) optionally applies actions (lambda). +/// +/// Nested patterns are meant to capture imperfectly nested loops while matching +/// properties over the whole loop nest. 
For instance, in vectorization we are
+/// interested in capturing all the imperfectly nested loops of a certain type
+/// and such that all the load and stores have certain access patterns along the
+/// loops' induction variables). Such NestedMatches are first captured using the
+/// `match` function and are later processed to analyze properties and apply
+/// transformations in a non-greedy way.
+///
+/// The NestedMatches captured in the IR can grow large, especially after
+/// aggressive unrolling. As experience has shown, it is generally better to use
+/// a plain walk over operations to match flat patterns but the current
+/// implementation is competitive nonetheless.
+using FilterFunctionType = std::function<bool(Operation &)>;
+inline bool defaultFilterFunction(Operation &) { return true; }
+class NestedPattern {
+public:
+  NestedPattern(ArrayRef<NestedPattern> nested,
+                FilterFunctionType filter = defaultFilterFunction);
+  NestedPattern(const NestedPattern &) = default;
+  NestedPattern &operator=(const NestedPattern &) = default;
+
+  /// Returns all the top-level matches in `func`.
+  void match(FuncOp func, SmallVectorImpl<NestedMatch> *matches) {
+    func.walk([&](Operation *op) { matchOne(op, matches); });
+  }
+
+  /// Returns all the top-level matches in `op`.
+  void match(Operation *op, SmallVectorImpl<NestedMatch> *matches) {
+    op->walk([&](Operation *child) { matchOne(child, matches); });
+  }
+
+  /// Returns the depth of the pattern.
+  unsigned getDepth() const;
+
+private:
+  friend class NestedPatternContext;
+  friend class NestedMatch;
+  friend struct State;
+
+  /// Underlying global bump allocator managed by a NestedPatternContext.
+  static llvm::BumpPtrAllocator *&allocator();
+
+  /// Matches this pattern against a single `op` and fills matches with the
+  /// result.
+  void matchOne(Operation *op, SmallVectorImpl<NestedMatch> *matches);
+
+  /// Nested patterns to be matched.
+  ArrayRef<NestedPattern> nestedPatterns;
+
+  /// Extra filter function to apply to prune patterns as the IR is walked.
+ FilterFunctionType filter; + + /// skip is an implementation detail needed so that we can implement match + /// without switching on the type of the Operation. The idea is that a + /// NestedPattern first checks if it matches locally and then recursively + /// applies its nested matchers to its elem->nested. Since we want to rely on + /// the existing operation walking functionality rather than duplicate + /// it, we allow an off-by-one traversal to account for the fact that we + /// write: + /// + /// void match(Operation *elem) { + /// for (auto &c : getNestedPatterns()) { + /// NestedPattern childPattern(...); + /// ^~~~ Needs off-by-one skip. + /// + Operation *skip; +}; + +/// RAII structure to transparently manage the bump allocator for +/// NestedPattern and NestedMatch classes. This avoids passing a context to +/// all the API functions. +class NestedPatternContext { +public: + NestedPatternContext() { + assert(NestedMatch::allocator() == nullptr && + "Only a single NestedPatternContext is supported"); + assert(NestedPattern::allocator() == nullptr && + "Only a single NestedPatternContext is supported"); + NestedMatch::allocator() = &allocator; + NestedPattern::allocator() = &allocator; + } + ~NestedPatternContext() { + NestedMatch::allocator() = nullptr; + NestedPattern::allocator() = nullptr; + } + llvm::BumpPtrAllocator allocator; +}; + +namespace matcher { +// Syntactic sugar NestedPattern builder functions. 
+NestedPattern Op(FilterFunctionType filter = defaultFilterFunction);
+NestedPattern If(NestedPattern child);
+NestedPattern If(FilterFunctionType filter, NestedPattern child);
+NestedPattern If(ArrayRef<NestedPattern> nested = {});
+NestedPattern If(FilterFunctionType filter,
+                 ArrayRef<NestedPattern> nested = {});
+NestedPattern For(NestedPattern child);
+NestedPattern For(FilterFunctionType filter, NestedPattern child);
+NestedPattern For(ArrayRef<NestedPattern> nested = {});
+NestedPattern For(FilterFunctionType filter,
+                  ArrayRef<NestedPattern> nested = {});
+
+bool isParallelLoop(Operation &op);
+bool isReductionLoop(Operation &op);
+bool isLoadOrStore(Operation &op);
+
+} // end namespace matcher
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
diff --git a/mlir/include/mlir/Analysis/Passes.h b/mlir/include/mlir/Analysis/Passes.h
new file mode 100644
index 0000000000000000000000000000000000000000..0bbc850e6c9b523e384972d9022e4cca0e26b0ad
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Passes.h
@@ -0,0 +1,36 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors in the
+// analysis library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PASSES_H
+#define MLIR_ANALYSIS_PASSES_H
+
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+
+class FuncOp;
+template <typename T> class OpPassBase;
+
+/// Creates a pass to check memref accesses in a Function.
+std::unique_ptr<OpPassBase<FuncOp>> createMemRefBoundCheckPass();
+
+/// Creates a pass to check memref access dependences in a Function.
+std::unique_ptr<OpPassBase<FuncOp>> createTestMemRefDependenceCheckPass();
+
+/// Creates a pass to test parallelism detection; emits note for parallel loops.
+std::unique_ptr<OpPassBase<FuncOp>> createParallelismDetectionTestPass();
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_PASSES_H
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
new file mode 100644
index 0000000000000000000000000000000000000000..d7b6e9570142dd051b19f8aa815b9e76e62cf39c
--- /dev/null
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -0,0 +1,206 @@
+//===- SliceAnalysis.h - Analysis for Transitive UseDef chains --*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_SLICEANALYSIS_H_
+#define MLIR_ANALYSIS_SLICEANALYSIS_H_
+
+#include <functional>
+#include <vector>
+
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+
+class Operation;
+
+/// Type of the condition to limit the propagation of transitive use-defs.
+/// This can be used in particular to limit the propagation to a given Scope or
+/// to avoid passing through certain types of operation in a configurable
+/// manner.
+using TransitiveFilter = std::function<bool(Operation *)>;
+
+/// Fills `forwardSlice` with the computed forward slice (i.e. all
+/// the transitive uses of op), **without** including that operation.
+///
+/// This additionally takes a TransitiveFilter which acts as a frontier:
+/// when looking at uses transitively, a operation that does not pass the
+/// filter is never propagated through. This allows in particular to carve out
+/// the scope within a ForInst or the scope within an IfInst.
+/// +/// The implementation traverses the use chains in postorder traversal for +/// efficiency reasons: if a operation is already in `forwardSlice`, no +/// need to traverse its uses again. Since use-def chains form a DAG, this +/// terminates. +/// +/// Upon return to the root call, `forwardSlice` is filled with a +/// postorder list of uses (i.e. a reverse topological order). To get a proper +/// topological order, we just just reverse the order in `forwardSlice` before +/// returning. +/// +/// Example starting from node 0 +/// ============================ +/// +/// 0 +/// ___________|___________ +/// 1 2 3 4 +/// |_______| |______| +/// | | | +/// | 5 6 +/// |___|_____________| +/// | | +/// 7 8 +/// |_______________| +/// | +/// 9 +/// +/// Assuming all local orders match the numbering order: +/// 1. after getting back to the root getForwardSlice, `forwardSlice` may +/// contain: +/// {9, 7, 8, 5, 1, 2, 6, 3, 4} +/// 2. reversing the result of 1. gives: +/// {4, 3, 6, 2, 1, 5, 8, 7, 9} +/// +void getForwardSlice( + Operation *op, llvm::SetVector *forwardSlice, + TransitiveFilter filter = /* pass-through*/ + [](Operation *) { return true; }); + +/// Fills `backwardSlice` with the computed backward slice (i.e. +/// all the transitive defs of op), **without** including that operation. +/// +/// This additionally takes a TransitiveFilter which acts as a frontier: +/// when looking at defs transitively, a operation that does not pass the +/// filter is never propagated through. This allows in particular to carve out +/// the scope within a ForInst or the scope within an IfInst. +/// +/// The implementation traverses the def chains in postorder traversal for +/// efficiency reasons: if a operation is already in `backwardSlice`, no +/// need to traverse its definitions again. Since useuse-def chains form a DAG, +/// this terminates. +/// +/// Upon return to the root call, `backwardSlice` is filled with a +/// postorder list of defs. 
This happens to be a topological order, from the +/// point of view of the use-def chains. +/// +/// Example starting from node 8 +/// ============================ +/// +/// 1 2 3 4 +/// |_______| |______| +/// | | | +/// | 5 6 +/// |___|_____________| +/// | | +/// 7 8 +/// |_______________| +/// | +/// 9 +/// +/// Assuming all local orders match the numbering order: +/// {1, 2, 5, 3, 4, 6} +/// +void getBackwardSlice( + Operation *op, llvm::SetVector *backwardSlice, + TransitiveFilter filter = /* pass-through*/ + [](Operation *) { return true; }); + +/// Iteratively computes backward slices and forward slices until +/// a fixed point is reached. Returns an `llvm::SetVector` which +/// **includes** the original operation. +/// +/// This allows building a slice (i.e. multi-root DAG where everything +/// that is reachable from an Value in forward and backward direction is +/// contained in the slice). +/// This is the abstraction we need to materialize all the operations for +/// supervectorization without worrying about orderings and Value +/// replacements. +/// +/// Example starting from any node +/// ============================== +/// +/// 1 2 3 4 +/// |_______| |______| +/// | | | | +/// | 5 6___| +/// |___|_____________| | +/// | | | +/// 7 8 | +/// |_______________| | +/// | | +/// 9 10 +/// +/// Return the whole DAG in some topological order. +/// +/// The implementation works by just filling up a worklist with iterative +/// alternate calls to `getBackwardSlice` and `getForwardSlice`. +/// +/// The following section describes some additional implementation +/// considerations for a potentially more efficient implementation but they are +/// just an intuition without proof, we still use a worklist for now. +/// +/// Additional implementation considerations +/// ======================================== +/// Consider the defs-op-uses hourglass. 
+/// ____ +/// \ / defs (in some topological order) +/// \/ +/// op +/// /\ +/// / \ uses (in some topological order) +/// /____\ +/// +/// We want to iteratively apply `getSlice` to construct the whole +/// list of Operation that are reachable by (use|def)+ from op. +/// We want the resulting slice in topological order. +/// Ideally we would like the ordering to be maintained in-place to avoid +/// copying Operation at each step. Keeping this ordering by construction +/// seems very unclear, so we list invariants in the hope of seeing whether +/// useful properties pop up. +/// +/// In the following: +/// we use |= for set inclusion; +/// we use << for set topological ordering (i.e. each pair is ordered). +/// +/// Assumption: +/// =========== +/// We wish to maintain the following property by a recursive argument: +/// """ +/// defs << {op} < getSlice( + Operation *op, + TransitiveFilter backwardFilter = /* pass-through*/ + [](Operation *) { return true; }, + TransitiveFilter forwardFilter = /* pass-through*/ + [](Operation *) { return true; }); + +/// Multi-root DAG topological sort. +/// Performs a topological sort of the Operation in the `toSort` SetVector. +/// Returns a topologically sorted SetVector. +llvm::SetVector +topologicalSort(const llvm::SetVector &toSort); + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_SLICEANALYSIS_H_ diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7cf1e5c971acc49c76150311754c75b7b28ebaee --- /dev/null +++ b/mlir/include/mlir/Analysis/Utils.h @@ -0,0 +1,295 @@ +//===- Utils.h - General analysis utilities ---------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes for various transformation utilities for +// memref's and non-loop IR structures. These are not passes by themselves but +// are used either by passes, optimization sequences, or in turn by other +// transformation utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_UTILS_H +#define MLIR_ANALYSIS_UTILS_H + +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Location.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir { + +class AffineForOp; +class Block; +class FlatAffineConstraints; +class Location; +struct MemRefAccess; +class Operation; +class Value; + +/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from +/// the outermost 'affine.for' operation to the innermost one. +// TODO(bondhugula): handle 'affine.if' ops. +void getLoopIVs(Operation &op, SmallVectorImpl *loops); + +/// Returns the nesting depth of this operation, i.e., the number of loops +/// surrounding this operation. +unsigned getNestingDepth(Operation &op); + +/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted +/// at 'forOp'. +void getSequentialLoops(AffineForOp forOp, + llvm::SmallDenseSet *sequentialLoops); + +/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their +/// associated operands for a set of loops within a loop nest (typically the +/// set of loops surrounding a store operation). Loop bound AffineMaps which +/// are non-null represent slices of that loop's iteration space. +struct ComputationSliceState { + // List of sliced loop IVs (ordered from outermost to innermost). + // EX: 'ivs[i]' has lower bound 'lbs[i]' and upper bound 'ubs[i]'. 
+  SmallVector<Value, 4> ivs;
+  // List of lower bound AffineMaps.
+  SmallVector<AffineMap, 4> lbs;
+  // List of upper bound AffineMaps.
+  SmallVector<AffineMap, 4> ubs;
+  // List of lower bound operands (lbOperands[i] are used by 'lbs[i]').
+  std::vector<SmallVector<Value, 4>> lbOperands;
+  // List of upper bound operands (ubOperands[i] are used by 'ubs[i]').
+  std::vector<SmallVector<Value, 4>> ubOperands;
+  // Slice loop nest insertion point in target loop nest.
+  Block::iterator insertPoint;
+  // Adds to 'cst' with constraints which represent the slice bounds on 'ivs'
+  // in 'this'. Specifically, the values in 'ivs' are added to 'cst' as dim
+  // identifiers and the values in 'lb/ubOperands' are added as symbols.
+  // Constraints are added for all loop IV bounds (dim or symbol), and
+  // constraints are added for slice bounds in 'lbs'/'ubs'.
+  // Returns failure if we cannot add loop bounds because of unsupported cases.
+  LogicalResult getAsConstraints(FlatAffineConstraints *cst);
+
+  // Clears all bounds and operands in slice state.
+  void clearBounds();
+};
+
+/// Computes the computation slice loop bounds for one loop nest as affine maps
+/// of the other loop nest's IVs and symbols, using 'dependenceConstraints'
+/// computed between 'depSourceAccess' and 'depSinkAccess'.
+/// If 'isBackwardSlice' is true, a backwards slice is computed in which the
+/// slice bounds of loop nest surrounding 'depSourceAccess' are computed in
+/// terms of loop IVs and symbols of the loop nest surrounding 'depSinkAccess'
+/// at 'loopDepth'.
+/// If 'isBackwardSlice' is false, a forward slice is computed in which the
+/// slice bounds of loop nest surrounding 'depSinkAccess' are computed in terms
+/// of loop IVs and symbols of the loop nest surrounding 'depSourceAccess' at
+/// 'loopDepth'.
+/// The slice loop bounds and associated operands are returned in 'sliceState'.
+// +// Backward slice example: +// +// affine.for %i0 = 0 to 10 { +// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// } +// affine.for %i1 = 0 to 10 { +// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// } +// +// // Backward computation slice of loop nest '%i0'. +// affine.for %i0 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 1)(%i1) { +// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// } +// +// Forward slice example: +// +// affine.for %i0 = 0 to 10 { +// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// } +// affine.for %i1 = 0 to 10 { +// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// } +// +// // Forward computation slice of loop nest '%i1'. +// affine.for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) { +// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// } +// +void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp, + FlatAffineConstraints *dependenceConstraints, + unsigned loopDepth, bool isBackwardSlice, + ComputationSliceState *sliceState); + +/// Computes in 'sliceUnion' the union of all slice bounds computed at +/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'. +/// The parameter 'numCommonLoops' is the number of loops common to the +/// operations in 'opsA' and 'opsB'. +/// If 'isBackwardSlice' is true, computes slice bounds for loop nest +/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest +/// surrounding ops in 'opsB' at 'loopDepth'. +/// If 'isBackwardSlice' is false, computes slice bounds for loop nest +/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest +/// surrounding ops in 'opsA' at 'loopDepth'. +/// Returns 'success' if union was computed, 'failure' otherwise. +// TODO(andydavis) Change this API to take 'forOpA'/'forOpB'. 
+LogicalResult computeSliceUnion(ArrayRef opsA, + ArrayRef opsB, unsigned loopDepth, + unsigned numCommonLoops, bool isBackwardSlice, + ComputationSliceState *sliceUnion); + +/// Creates a clone of the computation contained in the loop nest surrounding +/// 'srcOpInst', slices the iteration space of src loop based on slice bounds +/// in 'sliceState', and inserts the computation slice at the beginning of the +/// operation block of the loop at 'dstLoopDepth' in the loop nest surrounding +/// 'dstOpInst'. Returns the top-level loop of the computation slice on +/// success, returns nullptr otherwise. +// Loop depth is a crucial optimization choice that determines where to +// materialize the results of the backward slice - presenting a trade-off b/w +// storage and redundant computation in several cases. +// TODO(andydavis) Support computation slices with common surrounding loops. +AffineForOp insertBackwardComputationSlice(Operation *srcOpInst, + Operation *dstOpInst, + unsigned dstLoopDepth, + ComputationSliceState *sliceState); + +/// A region of a memref's data space; this is typically constructed by +/// analyzing load/store op's on this memref and the index space of loops +/// surrounding such op's. +// For example, the memref region for a load operation at loop depth = 1: +// +// affine.for %i = 0 to 32 { +// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { +// affine.load %A[%ii] +// } +// } +// +// Region: {memref = %A, write = false, {%i <= m0 <= %i + 7} } +// The last field is a 2-d FlatAffineConstraints symbolic in %i. +// +struct MemRefRegion { + explicit MemRefRegion(Location loc) : loc(loc) {} + + /// Computes the memory region accessed by this memref with the region + /// represented as constraints symbolic/parametric in 'loopDepth' loops + /// surrounding opInst. 
The computed region's 'cst' field has exactly as many + /// dimensional identifiers as the rank of the memref, and *potentially* + /// additional symbolic identifiers which could include any of the loop IVs + /// surrounding opInst up until 'loopDepth' and another additional Function + /// symbols involved with the access (for eg., those appear in affine.apply's, + /// loop bounds, etc.). If 'sliceState' is non-null, operands from + /// 'sliceState' are added as symbols, and the following constraints are added + /// to the system: + /// *) Inequality constraints which represent loop bounds for 'sliceState' + /// operands which are loop IVS (these represent the destination loop IVs + /// of the slice, and are added as symbols to MemRefRegion's constraint + /// system). + /// *) Inequality constraints for the slice bounds in 'sliceState', which + /// represent the bounds on the loop IVs in this constraint system w.r.t + /// to slice operands (which correspond to symbols). + /// If 'addMemRefDimBounds' is true, constant upper/lower bounds + /// [0, memref.getDimSize(i)) are added for each MemRef dimension 'i'. + /// + /// For example, the memref region for this operation at loopDepth = 1 will + /// be: + /// + /// affine.for %i = 0 to 32 { + /// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { + /// load %A[%ii] + /// } + /// } + /// + /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } + /// The last field is a 2-d FlatAffineConstraints symbolic in %i. 
+ /// + LogicalResult compute(Operation *op, unsigned loopDepth, + ComputationSliceState *sliceState = nullptr, + bool addMemRefDimBounds = true); + + FlatAffineConstraints *getConstraints() { return &cst; } + const FlatAffineConstraints *getConstraints() const { return &cst; } + bool isWrite() const { return write; } + void setWrite(bool flag) { write = flag; } + + /// Returns a constant upper bound on the number of elements in this region if + /// bounded by a known constant (always possible for static shapes), None + /// otherwise. Note that the symbols of the region are treated specially, + /// i.e., the returned bounding constant holds for *any given* value of the + /// symbol identifiers. The 'shape' vector is set to the corresponding + /// dimension-wise bounds major to minor. We use int64_t instead of uint64_t + /// since index types can be at most int64_t. + Optional getConstantBoundingSizeAndShape( + SmallVectorImpl *shape = nullptr, + std::vector> *lbs = nullptr, + SmallVectorImpl *lbDivisors = nullptr) const; + + /// A wrapper around FlatAffineConstraints::getConstantBoundOnDimSize(). 'pos' + /// corresponds to the position of the memref shape's dimension (major to + /// minor) which matches 1:1 with the dimensional identifier positions in + //'cst'. + Optional + getConstantBoundOnDimSize(unsigned pos, + SmallVectorImpl *lb = nullptr, + int64_t *lbFloorDivisor = nullptr) const { + assert(pos < getRank() && "invalid position"); + return cst.getConstantBoundOnDimSize(pos, lb); + } + + /// Returns the size of this MemRefRegion in bytes. + Optional getRegionSize(); + + // Wrapper around FlatAffineConstraints::unionBoundingBox. + LogicalResult unionBoundingBox(const MemRefRegion &other); + + /// Returns the rank of the memref that this region corresponds to. + unsigned getRank() const; + + /// Memref that this region corresponds to. + Value memref; + + /// Read or write. 
+ bool write; + + /// If there is more than one load/store op associated with the region, the + /// location information would correspond to one of those op's. + Location loc; + + /// Region (data space) of the memref accessed. This set will thus have at + /// least as many dimensional identifiers as the shape dimensionality of the + /// memref, and these are the leading dimensions of the set appearing in that + /// order (major to minor / outermost to innermost). There may be additional + /// identifiers since getMemRefRegion() is called with a specific loop depth, + /// and thus the region is symbolic in the outer surrounding loops at that + /// depth. + // TODO(bondhugula): Replace this to exploit HyperRectangularSet. + FlatAffineConstraints cst; +}; + +/// Returns the size of memref data in bytes if it's statically shaped, None +/// otherwise. +Optional getMemRefSizeInBytes(MemRefType memRefType); + +/// Checks a load or store op for an out of bound access; returns failure if the +/// access is out of bounds along any of the dimensions, success otherwise. +/// Emits a diagnostic error (with location information) if emitError is true. +template +LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, + bool emitError = true); + +/// Returns the number of surrounding loops common to both A and B. +unsigned getNumCommonSurroundingLoops(Operation &A, Operation &B); + +/// Gets the memory footprint of all data touched in the specified memory space +/// in bytes; if the memory space is unspecified, considers all memory spaces. +Optional getMemoryFootprintBytes(AffineForOp forOp, + int memorySpace = -1); + +/// Returns true if `forOp' is a parallel loop. 
+bool isLoopParallel(AffineForOp forOp); + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/include/mlir/Analysis/Verifier.h b/mlir/include/mlir/Analysis/Verifier.h new file mode 100644 index 0000000000000000000000000000000000000000..b7075b4f1578f1614dc6e0196bd0dab860f05fa9 --- /dev/null +++ b/mlir/include/mlir/Analysis/Verifier.h @@ -0,0 +1,22 @@ +//===- Verifier.h - Verifier analysis for MLIR structures -------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_VERIFIER_H +#define MLIR_ANALYSIS_VERIFIER_H + +namespace mlir { +struct LogicalResult; +class Operation; + +/// Perform (potentially expensive) checks of invariants, used to detect +/// compiler bugs, on this operation and any nested operations. On error, this +/// reports the error through the MLIRContext and returns failure. 
+LogicalResult verify(Operation *op); +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..84031a5e72a744ed8f58eab2d529f790105a7f88 --- /dev/null +++ b/mlir/include/mlir/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(Analysis) +add_subdirectory(Dialect) +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h new file mode 100644 index 0000000000000000000000000000000000000000..c6a2fac6ec9fa9458821143375fd7b0d4dd3c972 --- /dev/null +++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h @@ -0,0 +1,47 @@ +//===- AffineToStandard.h - Convert Affine to Standard dialect --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H +#define MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H + +#include "mlir/Support/LLVM.h" + +namespace mlir { +class AffineExpr; +class AffineForOp; +class Location; +struct LogicalResult; +class MLIRContext; +class OpBuilder; +class RewritePattern; +class Value; + +// Owning list of rewriting patterns. +class OwningRewritePatternList; + +/// Emit code that computes the given affine expression using standard +/// arithmetic operations applied to the provided dimension and symbol values. 
+Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, + ArrayRef dimValues, ArrayRef symbolValues); + +/// Collect a set of patterns to convert from the Affine dialect to the Standard +/// dialect, in particular convert structured affine control flow into CFG +/// branch-based control flow. +void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx); + +/// Emit code that computes the lower bound of the given affine loop using +/// standard arithmetic operations. +Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder); + +/// Emit code that computes the upper bound of the given affine loop using +/// standard arithmetic operations. +Value lowerAffineUpperBound(AffineForOp op, OpBuilder &builder); +} // namespace mlir + +#endif // MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h new file mode 100644 index 0000000000000000000000000000000000000000..4eb6379adf6e7b588dc902dc110722bf8016c120 --- /dev/null +++ b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -0,0 +1,55 @@ +//===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ +#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ + +#include "mlir/Support/LLVM.h" +#include +#include +#include +#include + +namespace mlir { + +class Location; +class ModuleOp; + +namespace LLVM { +class LLVMDialect; +} // namespace LLVM + +template class OpPassBase; + +using OwnedCubin = std::unique_ptr>; +using CubinGenerator = + std::function; + +/// Creates a pass to convert kernel functions into CUBIN blobs. +/// +/// This transformation takes the body of each function that is annotated with +/// the 'nvvm.kernel' attribute, copies it to a new LLVM module, compiles the +/// module with help of the nvptx backend to PTX and then invokes the provided +/// cubinGenerator to produce a binary blob (the cubin). Such blob is then +/// attached as a string attribute named 'nvvm.cubin' to the kernel function. +/// After the transformation, the body of the kernel function is removed (i.e., +/// it is turned into a declaration). +std::unique_ptr> +createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); + +/// Creates a pass to convert a gpu.launch_func operation into a sequence of +/// CUDA calls. +/// +/// This pass does not generate code to call CUDA directly but instead uses a +/// small wrapper library that exports a stable and conveniently typed ABI +/// on top of CUDA. 
+std::unique_ptr> +createConvertGpuLaunchFuncToCudaCallsPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h new file mode 100644 index 0000000000000000000000000000000000000000..75e4f7e374c6ae00971ab6349fbb36ba31aa6032 --- /dev/null +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -0,0 +1,29 @@ +//===- GPUToNVVMPass.h - Convert GPU kernel to NVVM dialect -----*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ +#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ + +#include + +namespace mlir { +class LLVMTypeConverter; +class OwningRewritePatternList; + +class ModuleOp; +template class OpPassBase; + +/// Collect a set of patterns to convert from the GPU dialect to NVVM. +void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + +/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. +std::unique_ptr> createLowerGpuOpsToNVVMOpsPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h new file mode 100644 index 0000000000000000000000000000000000000000..e913c2e1131a1712552991fdfe63dea6f8cfe863 --- /dev/null +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -0,0 +1,23 @@ +//===- GPUToROCDLPass.h - Convert GPU kernel to ROCDL dialect ---*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ +#define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ + +#include + +namespace mlir { + +class ModuleOp; +template class OpPassBase; + +/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. +std::unique_ptr> createLowerGpuOpsToROCDLOpsPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h new file mode 100644 index 0000000000000000000000000000000000000000..762a6e502d4e8d00e339139f1b62005391725e82 --- /dev/null +++ b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h @@ -0,0 +1,29 @@ +//===- ConvertGPUToSPIRV.h - GPU Ops to SPIR-V dialect patterns ----C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns for lowering GPU Ops to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H +#define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class SPIRVTypeConverter; +/// Appends to a pattern list additional patterns for translating GPU Ops to +/// SPIR-V ops. Needs the workgroup size as input since SPIR-V/Vulkan requires +/// the workgroup size to be statically specified. 
+void populateGPUToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns, + ArrayRef workGroupSize); +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h new file mode 100644 index 0000000000000000000000000000000000000000..37230f4c0e11f3f455471d3befab2e7afc3e4faf --- /dev/null +++ b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h @@ -0,0 +1,31 @@ +//===- ConvertGPUToSPIRVPass.h - GPU to SPIR-V conversion pass --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a pass to convert GPU ops to SPIRV ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H +#define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H + +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OpPassBase; + +/// Pass to convert GPU Ops to SPIR-V ops. Needs the workgroup size as input +/// since SPIR-V/Vulkan requires the workgroup size to be statically specified. 
+std::unique_ptr> +createConvertGPUToSPIRVPass(ArrayRef workGroupSize); + +} // namespace mlir +#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h new file mode 100644 index 0000000000000000000000000000000000000000..27950177c1d9cd096cc0eab3f45feccece73e184 --- /dev/null +++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h @@ -0,0 +1,30 @@ +//===- LinalgToLLVM.h - Utils to convert from the linalg dialect ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_ +#define MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class MLIRContext; + +class LinalgTypeConverter : public LLVMTypeConverter { +public: + using LLVMTypeConverter::LLVMTypeConverter; + Type convertType(Type t) override; +}; + +/// Populate the given list with patterns that convert from Linalg to LLVM. 
+void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, + OwningRewritePatternList &patterns, + MLIRContext *ctx); + +} // namespace mlir + +#endif // MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h b/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h new file mode 100644 index 0000000000000000000000000000000000000000..5cb8f59e6f7eb5026d1301a5cfe002e005fe4f00 --- /dev/null +++ b/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h @@ -0,0 +1,35 @@ +//===- ConvertLoopToStandard.h - Pass entrypoint ----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_LOOPTOSTANDARD_CONVERTLOOPTOSTANDARD_H_ +#define MLIR_CONVERSION_LOOPTOSTANDARD_CONVERTLOOPTOSTANDARD_H_ + +#include +#include + +namespace mlir { +struct LogicalResult; +class MLIRContext; +class Pass; +class RewritePattern; + +// Owning list of rewriting patterns. +class OwningRewritePatternList; + +/// Collect a set of patterns to lower from loop.for, loop.if, and +/// loop.terminator to CFG operations within the Standard dialect, in particular +/// convert structured control flow into CFG branch-based control flow. +void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx); + +/// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG. 
+std::unique_ptr createLowerToCFGPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_LOOPTOSTANDARD_CONVERTLOOPTOSTANDARD_H_ diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h new file mode 100644 index 0000000000000000000000000000000000000000..80faa03f31332dc4afa7af35436407e789246f6d --- /dev/null +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h @@ -0,0 +1,77 @@ +//===- LoopsToGPU.h - Convert loop nests to GPU kernels ---------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ +#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ + +#include "mlir/Support/LLVM.h" + +namespace mlir { +class AffineForOp; +struct LogicalResult; +class Value; + +namespace loop { +class ForOp; +} // end namespace loop + +/// Convert a perfect affine loop nest with the outermost loop identified by +/// `forOp` into a gpu::Launch operation. Map `numBlockDims` outer loops to +/// GPU blocks and `numThreadDims` to GPU threads. The bounds of the loops that +/// are mapped should be independent of the induction variables of the other +/// mapped loops. +/// +/// No check on the size of the block or grid, or on the validity of +/// parallelization is performed, it is under the responsibility of the caller +/// to strip-mine the loops and to perform the dependence analysis before +/// calling the conversion. +LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims); + +/// Convert a perfect linalg loop nest with the outermost loop identified by +/// `forOp` into a gpu::Launch operation. 
Map `numBlockDims` outer loops to +/// GPU blocks and `numThreadDims` to GPU threads. The bounds of the loops that +/// are mapped should be independent of the induction variables of the other +/// mapped loops. +/// +/// No check on the size of the block or grid, or on the validity of +/// parallelization is performed, it is under the responsibility of the caller +/// to strip-mine the loops and to perform the dependence analysis before +/// calling the conversion. +LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims); + +/// Convert a loop operation into a GPU launch with the values provided in +/// `numWorkGroups` as the grid size and the values provided in `workGroupSizes` +/// as the block size. Size of `numWorkGroups` and workGroupSizes` must be less +/// than or equal to 3. The loop operation can be an imperfectly nested +/// computation with the following restrictions: +/// 1) The loop nest must contain as many perfectly nested loops as the number +/// of values passed in through `numWorkGroups`. This corresponds to the number +/// of grid dimensions of the launch. All loops within the loop nest must be +/// parallel. +/// 2) The body of the innermost loop of the above perfectly nested loops, must +/// contain statements that satisfy one of the two conditions below: +/// a) A perfect loop nest of depth greater than or equal to the number of +/// values passed in through `workGroupSizes`, i.e. the number of thread +/// dimensions of the launch. Loops at depth less than or equal to size of +/// `workGroupSizes` must be parallel. Loops nested deeper can be sequential +/// and are retained as such in the generated GPU launch code. +/// b) Statements that are safe to be executed by all threads within the +/// workgroup. No checks are performed that this is indeed the case. +/// TODO(ravishankarm) : Add checks that verify 2(b) above. 
+/// The above conditions are assumed to be satisfied by the computation rooted +/// at `forOp`. +LogicalResult convertLoopToGPULaunch(loop::ForOp forOp, + ArrayRef numWorkGroups, + ArrayRef workGroupSizes); + +} // namespace mlir + +#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h new file mode 100644 index 0000000000000000000000000000000000000000..a3d663ae3d75e29cbb721fb569a1abad8d331678 --- /dev/null +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -0,0 +1,41 @@ +//===- LoopsToGPUPass.h - Pass converting loops to GPU kernels --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ +#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ + +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir { +class FuncOp; +template class OpPassBase; + +/// Create a pass that converts loop nests into GPU kernels. It considers +/// top-level affine.for and linalg.for operations as roots of loop nests and +/// converts them to the gpu.launch operations if possible. +/// +/// No check on the size of the block or grid, or on the validity of +/// parallelization is performed, it is under the responsibility of the caller +/// to strip-mine the loops and to perform the dependence analysis before +/// calling the conversion. +std::unique_ptr> +createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims); + +/// Create a pass that converts every loop operation within the body of the +/// FuncOp into a GPU launch. 
The number of workgroups and workgroup size for +/// the implementation is controlled by SSA values passed into conversion +/// method. For testing, the values are set as constants obtained from a command +/// line flag. See convertLoopToGPULaunch for a description of the required +/// semantics of the converted loop operation. +std::unique_ptr> +createLoopToGPUPass(ArrayRef numWorkGroups, + ArrayRef workGroupSize); +} // namespace mlir + +#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h new file mode 100644 index 0000000000000000000000000000000000000000..e78859f992bac930647e9d9d832939861917b1ec --- /dev/null +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -0,0 +1,244 @@ +//===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a dialect conversion targeting the LLVM IR dialect. By default, it +// converts Standard ops and types and provides hooks for dialect-specific +// extensions to the conversion. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H +#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace llvm { +class IntegerType; +class LLVMContext; +class Module; +class Type; +} // namespace llvm + +namespace mlir { + +class UnrankedMemRefType; + +namespace LLVM { +class LLVMDialect; +class LLVMType; +} // namespace LLVM + +/// Conversion from types in the Standard dialect to the LLVM IR dialect. +class LLVMTypeConverter : public TypeConverter { +public: + using TypeConverter::convertType; + + LLVMTypeConverter(MLIRContext *ctx); + + /// Convert types to LLVM IR. This calls `convertAdditionalType` to convert + /// non-standard or non-builtin types. + Type convertType(Type t) override; + + /// Convert a function type. The arguments and results are converted one by + /// one and results are packed into a wrapped LLVM IR structure type. `result` + /// is populated with argument mapping. + LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic, + SignatureConversion &result); + + /// Convert a non-empty list of types to be returned from a function into a + /// supported LLVM IR type. In particular, if more than one values is + /// returned, create an LLVM IR structure type with elements that correspond + /// to each of the MLIR types converted with `convertType`. + Type packFunctionResults(ArrayRef types); + + /// Returns the LLVM context. + llvm::LLVMContext &getLLVMContext(); + + /// Returns the LLVM dialect. + LLVM::LLVMDialect *getDialect() { return llvmDialect; } + + /// Promote the LLVM struct representation of all MemRef descriptors to stack + /// and use pointers to struct to avoid the complexity of the + /// platform-specific C/C++ ABI lowering related to struct argument passing. 
+ SmallVector promoteMemRefDescriptors(Location loc, + ValueRange opOperands, + ValueRange operands, + OpBuilder &builder); + + /// Promote the LLVM struct representation of one MemRef descriptor to stack + /// and use pointer to struct to avoid the complexity of the platform-specific + /// C/C++ ABI lowering related to struct argument passing. + Value promoteOneMemRefDescriptor(Location loc, Value operand, + OpBuilder &builder); + +protected: + /// LLVM IR module used to parse/create types. + llvm::Module *module; + LLVM::LLVMDialect *llvmDialect; + +private: + Type convertStandardType(Type type); + + // Convert a function type. The arguments and results are converted one by + // one. Additionally, if the function returns more than one value, pack the + // results into an LLVM IR structure type so that the converted function type + // returns at most one result. + Type convertFunctionType(FunctionType type); + + // Convert the index type. Uses llvmModule data layout to create an integer + // of the pointer bitwidth. + Type convertIndexType(IndexType type); + + // Convert an integer type `i*` to `!llvm<"i*">`. + Type convertIntegerType(IntegerType type); + + // Convert a floating point type: `f16` to `!llvm.half`, `f32` to + // `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported + // by LLVM. + Type convertFloatType(FloatType type); + + // Convert a memref type into an LLVM type that captures the relevant data. + // For statically-shaped memrefs, the resulting type is a pointer to the + // (converted) memref element type. For dynamically-shaped memrefs, the + // resulting type is an LLVM structure type that contains: + // 1. a pointer to the (converted) memref element type + // 2. as many index types as memref has dynamic dimensions. 
+ Type convertMemRefType(MemRefType type); + + // Convert an unranked memref type to an LLVM type that captures the + // runtime rank and a pointer to the static ranked memref desc + Type convertUnrankedMemRefType(UnrankedMemRefType type); + + // Convert a 1D vector type into an LLVM vector type. + Type convertVectorType(VectorType type); + + // Get the LLVM representation of the index type based on the bitwidth of the + // pointer as defined by the data layout of the module. + LLVM::LLVMType getIndexType(); + + // Extract an LLVM IR dialect type. + LLVM::LLVMType unwrap(Type type); +}; + +/// Helper class to produce LLVM dialect operations extracting or inserting +/// values to a struct. +class StructBuilder { +public: + /// Construct a helper for the given value. + explicit StructBuilder(Value v); + /// Builds IR creating an `undef` value of the descriptor type. + static StructBuilder undef(OpBuilder &builder, Location loc, + Type descriptorType); + + /*implicit*/ operator Value() { return value; } + +protected: + // LLVM value + Value value; + // Cached struct type. + Type structType; + +protected: + /// Builds IR to extract a value from the struct at position pos + Value extractPtr(OpBuilder &builder, Location loc, unsigned pos); + /// Builds IR to set a value in the struct at position pos + void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); +}; +/// Helper class to produce LLVM dialect operations extracting or inserting +/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. +/// The Value may be null, in which case none of the operations are valid. +class MemRefDescriptor : public StructBuilder { +public: + /// Construct a helper for the given descriptor value. + explicit MemRefDescriptor(Value descriptor); + /// Builds IR creating an `undef` value of the descriptor type. 
+ static MemRefDescriptor undef(OpBuilder &builder, Location loc, + Type descriptorType); + /// Builds IR creating a MemRef descriptor that represents `type` and + /// populates it with static shape and stride information extracted from the + /// type. + static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + MemRefType type, Value memory); + + /// Builds IR extracting the allocated pointer from the descriptor. + Value allocatedPtr(OpBuilder &builder, Location loc); + /// Builds IR inserting the allocated pointer into the descriptor. + void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); + + /// Builds IR extracting the aligned pointer from the descriptor. + Value alignedPtr(OpBuilder &builder, Location loc); + + /// Builds IR inserting the aligned pointer into the descriptor. + void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); + + /// Builds IR extracting the offset from the descriptor. + Value offset(OpBuilder &builder, Location loc); + + /// Builds IR inserting the offset into the descriptor. + void setOffset(OpBuilder &builder, Location loc, Value offset); + void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); + + /// Builds IR extracting the pos-th size from the descriptor. + Value size(OpBuilder &builder, Location loc, unsigned pos); + + /// Builds IR inserting the pos-th size into the descriptor + void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); + void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, + uint64_t size); + + /// Builds IR extracting the pos-th size from the descriptor. 
+ Value stride(OpBuilder &builder, Location loc, unsigned pos); + + /// Builds IR inserting the pos-th stride into the descriptor + void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); + void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, + uint64_t stride); + + /// Returns the (LLVM) type this descriptor points to. + LLVM::LLVMType getElementType(); + +private: + // Cached index type. + Type indexType; +}; + +class UnrankedMemRefDescriptor : public StructBuilder { +public: + /// Construct a helper for the given descriptor value. + explicit UnrankedMemRefDescriptor(Value descriptor); + /// Builds IR creating an `undef` value of the descriptor type. + static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, + Type descriptorType); + + /// Builds IR extracting the rank from the descriptor + Value rank(OpBuilder &builder, Location loc); + /// Builds IR setting the rank in the descriptor + void setRank(OpBuilder &builder, Location loc, Value value); + /// Builds IR extracting ranked memref descriptor ptr + Value memRefDescPtr(OpBuilder &builder, Location loc); + /// Builds IR setting ranked memref descriptor ptr + void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); +}; +/// Base class for operation conversions targeting the LLVM IR dialect. Provides +/// conversion patterns with an access to the containing LLVMLowering for the +/// purpose of type conversions. +class LLVMOpLowering : public ConversionPattern { +public: + LLVMOpLowering(StringRef rootOpName, MLIRContext *context, + LLVMTypeConverter &lowering, PatternBenefit benefit = 1); + +protected: + // Back-reference to the lowering class, used to call type and function + // conversions accounting for potential extensions. 
+ LLVMTypeConverter &lowering; +}; + +} // namespace mlir + +#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h new file mode 100644 index 0000000000000000000000000000000000000000..a4d95da6a75d41a8d1a10285888b212071582777 --- /dev/null +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -0,0 +1,109 @@ +//===- ConvertStandardToLLVMPass.h - Pass entrypoint ------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ +#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ + +#include "llvm/ADT/STLExtras.h" +#include +#include + +namespace llvm { +class Module; +} // namespace llvm + +namespace mlir { +class DialectConversion; +class FuncOp; +class LLVMTypeConverter; +struct LogicalResult; +class MLIRContext; +class ModuleOp; +template class OpPassBase; +class RewritePattern; +class Type; + +// Owning list of rewriting patterns. +class OwningRewritePatternList; + +/// Type for a callback constructing the owning list of patterns for the +/// conversion to the LLVMIR dialect. The callback is expected to append +/// patterns to the owning list provided as the second argument. +using LLVMPatternListFiller = + std::function; + +/// Type for a callback constructing the type converter for the conversion to +/// the LLVMIR dialect. The callback is expected to return an instance of the +/// converter. 
+using LLVMTypeConverterMaker =
+    std::function(MLIRContext *)>;
+
+/// Collect a set of patterns to convert memory-related operations from the
+/// Standard dialect to the LLVM dialect.
+void populateStdToLLVMMemoryConversionPatters(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Collect a set of patterns to convert from the Standard dialect to the LLVM
+/// dialect, excluding the memory-related operations.
+void populateStdToLLVMNonMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         OwningRewritePatternList &patterns);
+
+/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
+std::unique_ptr>
+createLowerToLLVMPass(bool useAlloca = false);
+
+/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
+/// is defined by a list of patterns and a type converter that will be obtained
+/// during the pass using the provided callbacks.
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
+std::unique_ptr>
+createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
+                      LLVMTypeConverterMaker typeConverterMaker,
+                      bool useAlloca = false);
+
+/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
+/// is defined by a list of patterns obtained during the pass using the provided
+/// callback and an optional type conversion class, an instance is created
+/// during the pass.
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
+template
+std::unique_ptr>
+createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
+                      bool useAlloca = false) {
+  return createLowerToLLVMPass(
+      patternListFiller,
+      [](MLIRContext *context) {
+        return std::make_unique(context);
+      },
+      useAlloca);
+}
+
+namespace LLVM {
+/// Make argument-taking successors of each block distinct. PHI nodes in LLVM
+/// IR use the predecessor ID to identify which value to take. They do not
+/// support different values coming from the same predecessor. If a block has
+/// another block as a successor more than once with different values, insert
+/// a new dummy block for LLVM PHI nodes to tell the sources apart.
+void ensureDistinctSuccessors(ModuleOp m);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
new file mode 100644
index 0000000000000000000000000000000000000000..e0e874027bf443d4737a9d384c99e421a298d186
--- /dev/null
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
@@ -0,0 +1,35 @@
+//===- ConvertStandardToSPIRV.h - Convert to SPIR-V dialect -----*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides patterns to lower StandardOps to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class SPIRVTypeConverter;
+
+/// Appends to a pattern list additional patterns for translating StandardOps
+/// to SPIR-V ops. Also adds the patterns to legalize ops not directly
+/// translated to SPIR-V dialect.
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+                                     SPIRVTypeConverter &typeConverter,
+                                     OwningRewritePatternList &patterns);
+
+/// Appends to a pattern list patterns to legalize ops that are not directly
+/// lowered to SPIR-V.
+void populateStdLegalizationPatternsForSPIRVLowering(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
new file mode 100644
index 0000000000000000000000000000000000000000..7dbaf1c04188d7c8ac5895cbe988ebfd182f54a7
--- /dev/null
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
@@ -0,0 +1,28 @@
+//===- ConvertStandardToSPIRVPass.h - StdOps to SPIR-V pass -----*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides a pass to lower from StandardOps to SPIR-V dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H +#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Pass to convert StandardOps to SPIR-V ops. +std::unique_ptr> createConvertStandardToSPIRVPass(); + +/// Pass to legalize ops that are not directly lowered to SPIR-V. +std::unique_ptr createLegalizeStdOpsForSPIRVLoweringPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h new file mode 100644 index 0000000000000000000000000000000000000000..b8b97c21a3efbf9ad398fddba199c3ad65671812 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -0,0 +1,27 @@ +//===- ConvertVectorToLLVM.h - Utils to convert from the vector dialect ---===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ +#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class LLVMTypeConverter; +class ModuleOp; +template class OpPassBase; + +/// Collect a set of patterns to convert from the Vector dialect to LLVM. +void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + +/// Create a pass to convert vector operations to the LLVMIR dialect. 
+OpPassBase *createLowerVectorToLLVMPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h b/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h new file mode 100644 index 0000000000000000000000000000000000000000..4f7d0843b7326211d61028afe2cf805ecaa00752 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h @@ -0,0 +1,27 @@ +//===- ConvertVectorToLoops.h - Utils to convert from the vector dialect --===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ +#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class MLIRContext; +class ModuleOp; +template class OpPassBase; + +/// Collect a set of patterns to convert from the Vector dialect to loops + std. +void populateVectorToAffineLoopsConversionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns); + +/// Create a pass to convert vector operations to affine loops + std dialect. 
+OpPassBase *createLowerVectorToLoopsPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h new file mode 100644 index 0000000000000000000000000000000000000000..b884ac5c2cea4f154e15e309bc719db9c4424a8e --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -0,0 +1,677 @@ +//===- AffineOps.h - MLIR Affine Operations -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines convenience types for working with Affine operations +// in the MLIR operation set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AFFINEOPS_AFFINEOPS_H +#define MLIR_DIALECT_AFFINEOPS_AFFINEOPS_H + +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/LoopLikeInterface.h" + +namespace mlir { +class AffineBound; +class AffineDimExpr; +class AffineValueMap; +class AffineTerminatorOp; +class FlatAffineConstraints; +class OpBuilder; + +/// A utility function to check if a value is defined at the top level of a +/// function. A value of index type defined at the top level is always a valid +/// symbol. +bool isTopLevelValue(Value value); + +class AffineOpsDialect : public Dialect { +public: + AffineOpsDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "affine"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. 
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; +}; + +/// The "affine.apply" operation applies an affine map to a list of operands, +/// yielding a single result. The operand list must be the same size as the +/// number of arguments to the affine mapping. All operands and the result are +/// of type 'Index'. This operation requires a single affine map attribute named +/// "map". For example: +/// +/// %y = "affine.apply" (%x) { map: (d0) -> (d0 + 1) } : +/// (index) -> (index) +/// +/// equivalently: +/// +/// #map42 = (d0)->(d0+1) +/// %y = affine.apply #map42(%x) +/// +class AffineApplyOp : public Op { +public: + using Op::Op; + + /// Builds an affine apply op with the specified map and operands. + static void build(Builder *builder, OperationState &result, AffineMap map, + ValueRange operands); + + /// Returns the affine map to be applied by this operation. + AffineMap getAffineMap() { + return getAttrOfType("map").getValue(); + } + + /// Returns true if the result of this operation can be used as dimension id. + bool isValidDim(); + + /// Returns true if the result of this operation is a symbol. + bool isValidSymbol(); + + static StringRef getOperationName() { return "affine.apply"; } + + operand_range getMapOperands() { return getOperands(); } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + OpFoldResult fold(ArrayRef operands); + + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); +}; + +/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data +/// from a source memref to a destination memref. The source and destination +/// memref need not be of the same dimensionality, but need to have the same +/// elemental type. 
The operands include the source and destination memref's +/// each followed by its indices, size of the data transfer in terms of the +/// number of elements (of the elemental type of the memref), a tag memref with +/// its indices, and optionally at the end, a stride and a +/// number_of_elements_per_stride arguments. The tag location is used by an +/// AffineDmaWaitOp to check for completion. The indices of the source memref, +/// destination memref, and the tag memref have the same restrictions as any +/// affine.load/store. In particular, index for each memref dimension must be an +/// affine expression of loop induction variables and symbols. +/// The optional stride arguments should be of 'index' type, and specify a +/// stride for the slower memory space (memory space with a lower memory space +/// id), transferring chunks of number_of_elements_per_stride every stride until +/// %num_elements are transferred. Either both or no stride arguments should be +/// specified. The value of 'num_elements' must be a multiple of +/// 'number_of_elements_per_stride'. +// +// For example, a DmaStartOp operation that transfers 256 elements of a memref +// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory +// space 1 at indices [%k + 7, %l], would be specified as follows: +// +// %num_elements = constant 256 +// %idx = constant 0 : index +// %tag = alloc() : memref<1xi32, 4> +// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], +// %num_elements : +// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> +// +// If %stride and %num_elt_per_stride are specified, the DMA is expected to +// transfer %num_elt_per_stride elements every %stride elements apart from +// memory space 0 until %num_elements are transferred. +// +// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, +// %stride, %num_elt_per_stride : ... 
+// +// TODO(mlir-team): add additional operands to allow source and destination +// striding, and multiple stride levels (possibly using AffineMaps to specify +// multiple levels of striding). +// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. +class AffineDmaStartOp : public Op { +public: + using Op::Op; + + static void build(Builder *builder, OperationState &result, Value srcMemRef, + AffineMap srcMap, ValueRange srcIndices, Value destMemRef, + AffineMap dstMap, ValueRange destIndices, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, Value numElements, + Value stride = nullptr, Value elementsPerStride = nullptr); + + /// Returns the operand index of the src memref. + unsigned getSrcMemRefOperandIndex() { return 0; } + + /// Returns the source MemRefType for this DMA operation. + Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } + MemRefType getSrcMemRefType() { + return getSrcMemRef()->getType().cast(); + } + + /// Returns the rank (number of indices) of the source MemRefType. + unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } + + /// Returns the affine map used to access the src memref. + AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } + AffineMapAttr getSrcMapAttr() { + return getAttr(getSrcMapAttrName()).cast(); + } + + /// Returns the source memref affine map indices for this DMA operation. + operand_range getSrcIndices() { + return {operand_begin() + getSrcMemRefOperandIndex() + 1, + operand_begin() + getSrcMemRefOperandIndex() + 1 + + getSrcMap().getNumInputs()}; + } + + /// Returns the memory space of the src memref. + unsigned getSrcMemorySpace() { + return getSrcMemRef()->getType().cast().getMemorySpace(); + } + + /// Returns the operand index of the dst memref. + unsigned getDstMemRefOperandIndex() { + return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); + } + + /// Returns the destination MemRefType for this DMA operations. 
+ Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } + MemRefType getDstMemRefType() { + return getDstMemRef()->getType().cast(); + } + + /// Returns the rank (number of indices) of the destination MemRefType. + unsigned getDstMemRefRank() { + return getDstMemRef()->getType().cast().getRank(); + } + + /// Returns the memory space of the src memref. + unsigned getDstMemorySpace() { + return getDstMemRef()->getType().cast().getMemorySpace(); + } + + /// Returns the affine map used to access the dst memref. + AffineMap getDstMap() { return getDstMapAttr().getValue(); } + AffineMapAttr getDstMapAttr() { + return getAttr(getDstMapAttrName()).cast(); + } + + /// Returns the destination memref indices for this DMA operation. + operand_range getDstIndices() { + return {operand_begin() + getDstMemRefOperandIndex() + 1, + operand_begin() + getDstMemRefOperandIndex() + 1 + + getDstMap().getNumInputs()}; + } + + /// Returns the operand index of the tag memref. + unsigned getTagMemRefOperandIndex() { + return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); + } + + /// Returns the Tag MemRef for this DMA operation. + Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } + MemRefType getTagMemRefType() { + return getTagMemRef()->getType().cast(); + } + + /// Returns the rank (number of indices) of the tag MemRefType. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + /// Returns the affine map used to access the tag memref. + AffineMap getTagMap() { return getTagMapAttr().getValue(); } + AffineMapAttr getTagMapAttr() { + return getAttr(getTagMapAttrName()).cast(); + } + + /// Returns the tag memref indices for this DMA operation. + operand_range getTagIndices() { + return {operand_begin() + getTagMemRefOperandIndex() + 1, + operand_begin() + getTagMemRefOperandIndex() + 1 + + getTagMap().getNumInputs()}; + } + + /// Returns the number of elements being transferred by this DMA operation. 
+ Value getNumElements() { + return getOperand(getTagMemRefOperandIndex() + 1 + + getTagMap().getNumInputs()); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value memref) { + if (memref == getSrcMemRef()) + return {Identifier::get(getSrcMapAttrName(), getContext()), + getSrcMapAttr()}; + else if (memref == getDstMemRef()) + return {Identifier::get(getDstMapAttrName(), getContext()), + getDstMapAttr()}; + assert(memref == getTagMemRef() && + "DmaStartOp expected source, destination or tag memref"); + return {Identifier::get(getTagMapAttrName(), getContext()), + getTagMapAttr()}; + } + + /// Returns true if this is a DMA from a faster memory space to a slower one. + bool isDestMemorySpaceFaster() { + return (getSrcMemorySpace() < getDstMemorySpace()); + } + + /// Returns true if this is a DMA from a slower memory space to a faster one. + bool isSrcMemorySpaceFaster() { + // Assumes that a lower number is for a slower memory space. + return (getDstMemorySpace() < getSrcMemorySpace()); + } + + /// Given a DMA start operation, returns the operand position of either the + /// source or destination memref depending on the one that is at the higher + /// level of the memory hierarchy. Asserts failure if neither is true. + unsigned getFasterMemPos() { + assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); + return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); + } + + static StringRef getSrcMapAttrName() { return "src_map"; } + static StringRef getDstMapAttrName() { return "dst_map"; } + static StringRef getTagMapAttrName() { return "tag_map"; } + + static StringRef getOperationName() { return "affine.dma_start"; } + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); + + /// Returns true if this DMA operation is strided, returns false otherwise. 
+ bool isStrided() { + return getNumOperands() != + getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; + } + + /// Returns the stride value for this DMA operation. + Value getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1 - 1); + } + + /// Returns the number of elements to transfer per stride for this DMA op. + Value getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } +}; + +/// AffineDmaWaitOp blocks until the completion of a DMA operation associated +/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be +/// an index with the same restrictions as any load/store index. In particular, +/// index for each memref dimension must be an affine expression of loop +/// induction variables and symbols. %num_elements is the number of elements +/// associated with the DMA operation. For example: +// +// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : +// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> +// ... +// ... +// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> +// +class AffineDmaWaitOp : public Op { +public: + using Op::Op; + + static void build(Builder *builder, OperationState &result, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, Value numElements); + + static StringRef getOperationName() { return "affine.dma_wait"; } + + // Returns the Tag MemRef associated with the DMA operation being waited on. + Value getTagMemRef() { return getOperand(0); } + MemRefType getTagMemRefType() { + return getTagMemRef()->getType().cast(); + } + + /// Returns the affine map used to access the tag memref. + AffineMap getTagMap() { return getTagMapAttr().getValue(); } + AffineMapAttr getTagMapAttr() { + return getAttr(getTagMapAttrName()).cast(); + } + + // Returns the tag memref index for this DMA operation. 
+ operand_range getTagIndices() { + return {operand_begin() + 1, + operand_begin() + 1 + getTagMap().getNumInputs()}; + } + + // Returns the rank (number of indices) of the tag memref. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value memref) { + assert(memref == getTagMemRef()); + return {Identifier::get(getTagMapAttrName(), getContext()), + getTagMapAttr()}; + } + + /// Returns the number of elements transferred in the associated DMA op. + Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } + + static StringRef getTagMapAttrName() { return "tag_map"; } + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); +}; + +/// The "affine.load" op reads an element from a memref, where the index +/// for each memref dimension is an affine expression of loop induction +/// variables and symbols. The output of 'affine.load' is a new value with the +/// same type as the elements of the memref. An affine expression of loop IVs +/// and symbols must be specified for each dimension of the memref. The keyword +/// 'symbol' can be used to indicate SSA identifiers which are symbolic. +// +// Example 1: +// +// %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// +// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. +// +// %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)] +// : memref<100x100xf32> +// +class AffineLoadOp : public Op::Impl> { +public: + using Op::Op; + + /// Builds an affine load op with the specified map and operands. + static void build(Builder *builder, OperationState &result, AffineMap map, + ValueRange operands); + /// Builds an affine load op with an identity map and operands. 
+ static void build(Builder *builder, OperationState &result, Value memref, + ValueRange indices = {}); + /// Builds an affine load op with the specified map and its operands. + static void build(Builder *builder, OperationState &result, Value memref, + AffineMap map, ValueRange mapOperands); + + /// Returns the operand index of the memref. + unsigned getMemRefOperandIndex() { return 0; } + + /// Get memref operand. + Value getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + /// Get affine map operands. + operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); } + + /// Returns the affine map used to index the memref for this operation. + AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } + AffineMapAttr getAffineMapAttr() { + return getAttr(getMapAttrName()).cast(); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value memref) { + assert(memref == getMemRef()); + return {Identifier::get(getMapAttrName(), getContext()), + getAffineMapAttr()}; + } + + static StringRef getMapAttrName() { return "map"; } + static StringRef getOperationName() { return "affine.load"; } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + OpFoldResult fold(ArrayRef operands); +}; + +/// The "affine.store" op writes an element to a memref, where the index +/// for each memref dimension is an affine expression of loop induction +/// variables and symbols. The 'affine.store' op stores a new value which is the +/// same type as the elements of the memref. 
An affine expression of loop IVs +/// and symbols must be specified for each dimension of the memref. The keyword +/// 'symbol' can be used to indicate SSA identifiers which are symbolic. +// +// Example 1: +// +// affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// +// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. +// +// affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] +// : memref<100x100xf32> +// +class AffineStoreOp : public Op::Impl> { +public: + using Op::Op; + + /// Builds an affine store operation with the provided indices (identity map). + static void build(Builder *builder, OperationState &result, + Value valueToStore, Value memref, ValueRange indices); + /// Builds an affine store operation with the specified map and its operands. + static void build(Builder *builder, OperationState &result, + Value valueToStore, Value memref, AffineMap map, + ValueRange mapOperands); + + /// Get value to be stored by store operation. + Value getValueToStore() { return getOperand(0); } + + /// Returns the operand index of the memref. + unsigned getMemRefOperandIndex() { return 1; } + + /// Get memref operand. + Value getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); } + + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + /// Get affine map operands. + operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); } + + /// Returns the affine map used to index the memref for this operation. + AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } + AffineMapAttr getAffineMapAttr() { + return getAttr(getMapAttrName()).cast(); + } + + /// Returns the AffineMapAttr associated with 'memref'. 
+ NamedAttribute getAffineMapAttrForMemRef(Value memref) { + assert(memref == getMemRef()); + return {Identifier::get(getMapAttrName(), getContext()), + getAffineMapAttr()}; + } + + static StringRef getMapAttrName() { return "map"; } + static StringRef getOperationName() { return "affine.store"; } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); +}; + +/// Returns true if the given Value can be used as a dimension id. +bool isValidDim(Value value); + +/// Returns true if the given Value can be used as a symbol. +bool isValidSymbol(Value value); + +/// Modifies both `map` and `operands` in-place so as to: +/// 1. drop duplicate operands +/// 2. drop unused dims and symbols from map +/// 3. promote valid symbols to symbolic operands in case they appeared as +/// dimensional operands +/// 4. propagate constant operands and drop them +void canonicalizeMapAndOperands(AffineMap *map, + SmallVectorImpl *operands); +/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does +/// for affine maps. +void canonicalizeSetAndOperands(IntegerSet *set, + SmallVectorImpl *operands); + +/// Returns a composed AffineApplyOp by composing `map` and `operands` with +/// other AffineApplyOps supplying those operands. The operands of the resulting +/// AffineApplyOp do not change the length of AffineApplyOp chains. +AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, + ArrayRef operands); + +/// Given an affine map `map` and its input `operands`, this method composes +/// into `map`, maps of AffineApplyOps whose results are the values in +/// `operands`, iteratively until no more of `operands` are the result of an +/// AffineApplyOp. 
When this function returns, `map` becomes the composed affine +/// map, and each Value in `operands` is guaranteed to be either a loop IV or a +/// terminal symbol, i.e., a symbol defined at the top level or a block/function +/// argument. +void fullyComposeAffineMapAndOperands(AffineMap *map, + SmallVectorImpl *operands); + +#define GET_OP_CLASSES +#include "mlir/Dialect/AffineOps/AffineOps.h.inc" + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool isForInductionVar(Value val); + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +AffineForOp getForInductionVarOwner(Value val); + +/// Extracts the induction variables from a list of AffineForOps and places them +/// in the output argument `ivs`. +void extractForInductionVars(ArrayRef forInsts, + SmallVectorImpl *ivs); + +/// AffineBound represents a lower or upper bound in the for operation. +/// This class does not own the underlying operands. Instead, it refers +/// to the operands stored in the AffineForOp. Its life span should not exceed +/// that of the for operation it refers to. +class AffineBound { +public: + AffineForOp getAffineForOp() { return op; } + AffineMap getMap() { return map; } + + /// Returns an AffineValueMap representing this bound. + AffineValueMap getAsAffineValueMap(); + + unsigned getNumOperands() { return opEnd - opStart; } + Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); } + + using operand_iterator = AffineForOp::operand_iterator; + using operand_range = AffineForOp::operand_range; + + operand_iterator operand_begin() { return op.operand_begin() + opStart; } + operand_iterator operand_end() { return op.operand_begin() + opEnd; } + operand_range getOperands() { return {operand_begin(), operand_end()}; } + +private: + // 'affine.for' operation that contains this bound. 
+ AffineForOp op; + // Start and end positions of this affine bound operands in the list of + // the containing 'affine.for' operation operands. + unsigned opStart, opEnd; + // Affine map for this bound. + AffineMap map; + + AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map) + : op(op), opStart(opStart), opEnd(opEnd), map(map) {} + + friend class AffineForOp; +}; + +/// An `AffineApplyNormalizer` is a helper class that supports renumbering +/// operands of AffineApplyOp. This acts as a reindexing map of Value to +/// positional dims or symbols and allows simplifications such as: +/// +/// ```mlir +/// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) +/// ``` +/// +/// into: +/// +/// ```mlir +/// %1 = affine.apply () -> (0) +/// ``` +struct AffineApplyNormalizer { + AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands); + + /// Returns the AffineMap resulting from normalization. + AffineMap getAffineMap() { return affineMap; } + + SmallVector<Value, 8> getOperands() { + SmallVector<Value, 8> res(reorderedDims); + res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); + return res; + } + + unsigned getNumSymbols() { return concatenatedSymbols.size(); } + unsigned getNumDims() { return reorderedDims.size(); } + + /// Normalizes 'otherMap' and its operands 'otherOperands' to map to this + /// normalizer's coordinate space. + void normalize(AffineMap *otherMap, SmallVectorImpl<Value> *otherOperands); + +private: + /// Helper function to insert `v` into the coordinate system of the current + /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding + /// renumbered position. + AffineDimExpr renumberOneDim(Value v); + + /// Given an `other` normalizer, this rewrites `other.affineMap` in the + /// coordinate system of the current AffineApplyNormalizer. + /// Returns the rewritten AffineMap and updates the dims and symbols of + /// `this`. + AffineMap renumber(const AffineApplyNormalizer &other); + + /// Maps of Value to position in `affineMap`. + DenseMap<Value, unsigned> dimValueToPosition; + + /// Ordered dims and symbols matching positional dims and symbols in + /// `affineMap`. + SmallVector<Value, 8> reorderedDims; + SmallVector<Value, 8> concatenatedSymbols; + + AffineMap affineMap; + + /// Used with RAII to control the depth at which AffineApply are composed + /// recursively. Only accepts depth 1 for now to allow a behavior where a + /// newly composed AffineApplyOp does not increase the length of the chain of + /// AffineApplyOps. Full composition is implemented iteratively on top of + /// this behavior. + static unsigned &affineApplyDepth() { + static thread_local unsigned depth = 0; + return depth; + } + static constexpr unsigned kMaxAffineApplyDepth = 1; + + AffineApplyNormalizer() { affineApplyDepth()++; } + +public: + ~AffineApplyNormalizer() { affineApplyDepth()--; } +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td new file mode 100644 index 0000000000000000000000000000000000000000..114e20513b2ae0f9721e34e683ba759cbd5af1b0 --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -0,0 +1,350 @@ +//===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines MLIR affine operations.
+// +//===----------------------------------------------------------------------===// + +#ifndef AFFINE_OPS +#define AFFINE_OPS + +include "mlir/Dialect/AffineOps/AffineOpsBase.td" +include "mlir/IR/OpBase.td" +include "mlir/Transforms/LoopLikeInterface.td" + +def Affine_Dialect : Dialect { + let name = "affine"; + let cppNamespace = ""; +} + +// Base class for Affine dialect ops. +class Affine_Op traits = []> : + Op { + // For every affine op, there needs to be a: + // * void print(OpAsmPrinter &p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, + // OperationState &result) + // functions. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +// Require regions to have affine terminator. +def ImplicitAffineTerminator + : SingleBlockImplicitTerminator<"AffineTerminatorOp">; + +def AffineForOp : Affine_Op<"for", + [ImplicitAffineTerminator, + DeclareOpInterfaceMethods]> { + let summary = "for operation"; + let description = [{ + The "affine.for" operation represents an affine loop nest, defining an SSA + value for its induction variable. It has one region capturing the loop body. + The induction variable is represented as a argument of this region. This SSA + value always has type index, which is the size of the machine word. The + stride, represented by step, is a positive constant integer which defaults + to "1" if not present. The lower and upper bounds specify a half-open range: + the range includes the lower bound but does not include the upper bound. + + The body region must contain exactly one block that terminates with + "affine.terminator". Calling AffineForOp::build will create such region + and insert the terminator, so will the parsing even in cases if it is absent + from the custom format. 
+ + The lower and upper bounds of a for operation are represented as an + application of an affine mapping to a list of SSA values passed to the map. + The same restrictions hold for these SSA values as for all bindings of SSA + values to dimensions and symbols. The affine mappings for the bounds may + return multiple results, in which case the max/min keywords are required + (for the lower/upper bound respectively), and the bound is the + maximum/minimum of the returned values. + + Example: + + affine.for %i = 1 to 10 { + ... + } + + }]; + let arguments = (ins Variadic); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "int64_t lowerBound, int64_t upperBound, int64_t step = 1">, + OpBuilder<"Builder *builder, OperationState &result, " + "ValueRange lbOperands, AffineMap lbMap, " + "ValueRange ubOperands, AffineMap ubMap, " + "int64_t step = 1"> + ]; + + let extraClassDeclaration = [{ + static StringRef getStepAttrName() { return "step"; } + static StringRef getLowerBoundAttrName() { return "lower_bound"; } + static StringRef getUpperBoundAttrName() { return "upper_bound"; } + + Block *getBody() { return ®ion().front(); } + Value getInductionVar() { return getBody()->getArgument(0); } + OpBuilder getBodyBuilder() { + return OpBuilder(getBody(), std::prev(getBody()->end())); + } + + // TODO: provide iterators for the lower and upper bound operands + // if the current access via getLowerBound(), getUpperBound() is too slow. + + /// Returns operands for the lower bound map. + operand_range getLowerBoundOperands(); + + /// Returns operands for the upper bound map. + operand_range getUpperBoundOperands(); + + /// Returns information about the lower bound as a single object. + AffineBound getLowerBound(); + + /// Returns information about the upper bound as a single object. + AffineBound getUpperBound(); + + /// Returns loop step. 
+ int64_t getStep() { + return getAttr(getStepAttrName()).cast().getInt(); + } + + /// Returns affine map for the lower bound. + AffineMap getLowerBoundMap() { return getLowerBoundMapAttr().getValue(); } + AffineMapAttr getLowerBoundMapAttr() { + return getAttr(getLowerBoundAttrName()).cast(); + } + /// Returns affine map for the upper bound. The upper bound is exclusive. + AffineMap getUpperBoundMap() { return getUpperBoundMapAttr().getValue(); } + AffineMapAttr getUpperBoundMapAttr() { + return getAttr(getUpperBoundAttrName()).cast(); + } + + /// Set lower bound. The new bound must have the same number of operands as + /// the current bound map. Otherwise, 'replaceForLowerBound' should be used. + void setLowerBound(ValueRange operands, AffineMap map); + /// Set upper bound. The new bound must not have more operands than the + /// current bound map. Otherwise, 'replaceForUpperBound' should be used. + void setUpperBound(ValueRange operands, AffineMap map); + + /// Set the lower bound map without changing operands. + void setLowerBoundMap(AffineMap map); + + /// Set the upper bound map without changing operands. + void setUpperBoundMap(AffineMap map); + + /// Set loop step. + void setStep(int64_t step) { + assert(step > 0 && "step has to be a positive integer constant"); + auto *context = getLowerBoundMap().getContext(); + setAttr(Identifier::get(getStepAttrName(), context), + IntegerAttr::get(IndexType::get(context), step)); + } + + /// Returns true if the lower bound is constant. + bool hasConstantLowerBound(); + /// Returns true if the upper bound is constant. + bool hasConstantUpperBound(); + /// Returns true if both bounds are constant. + bool hasConstantBounds() { + return hasConstantLowerBound() && hasConstantUpperBound(); + } + /// Returns the value of the constant lower bound. + /// Fails assertion if the bound is non-constant. + int64_t getConstantLowerBound(); + /// Returns the value of the constant upper bound. The upper bound is + /// exclusive. 
Fails assertion if the bound is non-constant. + int64_t getConstantUpperBound(); + /// Sets the lower bound to the given constant value. + void setConstantLowerBound(int64_t value); + /// Sets the upper bound to the given constant value. + void setConstantUpperBound(int64_t value); + + /// Returns true if both the lower and upper bound have the same operand + /// lists (same operands in the same order). + bool matchingBoundOperandList(); + }]; + + let hasCanonicalizer = 1; + let hasFolder = 1; +} + +def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { + let summary = "if-then-else operation"; + let description = [{ + The "if" operation represents an if-then-else construct for conditionally + executing two regions of code. The operands to an if operation are an + IntegerSet condition and a set of symbol/dimension operands to the + condition set. The operation produces no results. For example: + + affine.if #set(%i) { + ... + } else { + ... + } + + The 'else' blocks to the if operation are optional, and may be omitted. For + example: + + affine.if #set(%i) { + ... + } + }]; + let arguments = (ins Variadic); + let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "IntegerSet set, ValueRange args, bool withElseRegion"> + ]; + + let extraClassDeclaration = [{ + static StringRef getConditionAttrName() { return "condition"; } + + IntegerSet getIntegerSet(); + void setIntegerSet(IntegerSet newSet); + + /// Sets the integer set with its operands. The size of 'operands' must not + /// exceed the current number of operands for this instance, as the operands + /// list of AffineIf is not resizable. 
+ void setConditional(IntegerSet set, ValueRange operands); + + OpBuilder getThenBodyBuilder() { + assert(!thenRegion().empty() && "Unexpected empty 'then' region."); + Block &body = thenRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + OpBuilder getElseBodyBuilder() { + assert(!elseRegion().empty() && "Unexpected empty 'else' region."); + Block &body = elseRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + }]; + + let hasFolder = 1; +} + +def AffineMinOp : Affine_Op<"min"> { + let summary = "min operation"; + let description = [{ + The "min" operation computes the minimum value result from a multi-result + affine map. + + Example: + + %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) : index + }]; + let arguments = (ins AffineMapAttr:$map, Variadic:$operands); + let results = (outs Index); + let extraClassDeclaration = [{ + static StringRef getMapAttrName() { return "map"; } + }]; + let hasFolder = 1; +} + +def AffinePrefetchOp : Affine_Op<"prefetch"> { + let summary = "affine prefetch operation"; + let description = [{ + The "affine.prefetch" op prefetches data from a memref location described + with an affine subscript similar to affine.load, and has three attributes: + a read/write specifier, a locality hint, and a cache type specifier as shown + below: + + affine.prefetch %0[%i, %j + 5], read, locality<3>, data + : memref<400x400xi32> + + The read/write specifier is either 'read' or 'write', the locality hint + specifier ranges from locality<0> (no locality) to locality<3> (extremely + local keep in cache). The cache type specifier is either 'data' or 'instr' + and specifies whether the prefetch is performed on data cache or on + instruction cache. 
+ }]; + + let arguments = (ins AnyMemRef:$memref, Variadic:$indices, + BoolAttr:$isWrite, + Confined, + IntMaxValue<3>]>:$localityHint, + BoolAttr:$isDataCache); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value memref," + "AffineMap map, ArrayRef mapOperands, bool isWrite," + "unsigned localityHint, bool isDataCache", + [{ + assert(map.getNumInputs() == mapOperands.size() + && "inconsistent index info"); + auto localityHintAttr = builder->getI32IntegerAttr(localityHint); + auto isWriteAttr = builder->getBoolAttr(isWrite); + auto isDataCacheAttr = builder->getBoolAttr(isDataCache); + result.addOperands(memref); + result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); + result.addOperands(mapOperands); + result.addAttribute(getLocalityHintAttrName(), localityHintAttr); + result.addAttribute(getIsWriteAttrName(), isWriteAttr); + result.addAttribute(getIsDataCacheAttrName(), isDataCacheAttr); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return memref()->getType().cast(); + } + + /// Returns the affine map used to index the memref for this operation. + AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } + AffineMapAttr getAffineMapAttr() { + return getAttr(getMapAttrName()).cast(); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value mref) { + assert(mref == memref()); + return {Identifier::get(getMapAttrName(), getContext()), + getAffineMapAttr()}; + } + + /// Get affine map operands. 
+ operand_range getMapOperands() { + return {operand_begin() + 1, operand_end()}; + } + + static StringRef getMapAttrName() { return "map"; } + static StringRef getLocalityHintAttrName() { return "localityHint"; } + static StringRef getIsWriteAttrName() { return "isWrite"; } + static StringRef getIsDataCacheAttrName() { return "isDataCache"; } + }]; + + let hasCanonicalizer = 1; + let hasFolder = 1; +} + +def AffineTerminatorOp : + Affine_Op<"terminator", [Terminator]> { + let summary = "affine terminator operation"; + let description = [{ + Affine terminator is a special terminator operation for blocks inside affine + loops and branches. It unconditionally transmits the control flow to the + successor of the operation enclosing the region. + + This operation does _not_ have a custom syntax. However, affine control + operations omit the terminator in their custom syntax for brevity. + }]; + + // No custom parsing/printing form. + let parser = ?; + let printer = ?; + + // Fully specified by traits. + let verifier = ?; +} + +#endif // AFFINE_OPS diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td new file mode 100644 index 0000000000000000000000000000000000000000..6aee5f3cd4a51fcc67806a2d263a324dbe3aa50a --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td @@ -0,0 +1,31 @@ +//===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines base support for MLIR affine operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef AFFINE_OPS_BASE +#define AFFINE_OPS_BASE + +include "mlir/IR/OpBase.td" + +// Attributes containing affine maps. +def AffineMapAttr : Attr< + CPred<"$_self.isa()">, "AffineMap attribute"> { + let storageType = [{ AffineMapAttr }]; + let returnType = [{ AffineMap }]; + let constBuilderCall = "AffineMapAttr::get($0)"; +} + +def AffineMapArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)"; +} + +#endif // AFFINE_OPS_BASE diff --git a/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7339bcc9dcfd6406ec3e358487d09caa9edde88c --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(AffineOps AffineOps) diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9235436995aea83a6b1736b897f97abb74a7f6bc --- /dev/null +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -0,0 +1,10 @@ +add_subdirectory(AffineOps) +add_subdirectory(FxpMathOps) +add_subdirectory(GPU) +add_subdirectory(Linalg) +add_subdirectory(LLVMIR) +add_subdirectory(LoopOps) +add_subdirectory(QuantOps) +add_subdirectory(SPIRV) +add_subdirectory(StandardOps) +add_subdirectory(VectorOps) diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h new file mode 100644 index 0000000000000000000000000000000000000000..d667de73d4194eb77b5ccc34324b71d4c3e5187f --- /dev/null +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -0,0 +1,73 @@ +//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares various common operation folders. These folders +// are intended to be used by dialects to support common folding behavior +// without requiring each dialect to provide its own implementation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_COMMONFOLDERS_H +#define MLIR_DIALECT_COMMONFOLDERS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +/// Performs constant folding `calculate` with element-wise behavior on the two +/// attributes in `operands` and returns the result if possible. +template <class AttrElementT, + class ElementValueT = typename AttrElementT::ValueType, + class CalculationT = + function_ref<ElementValueT(ElementValueT, ElementValueT)>> +Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, + const CalculationT &calculate) { + assert(operands.size() == 2 && "binary op takes two operands"); + if (!operands[0] || !operands[1]) + return {}; + if (operands[0].getType() != operands[1].getType()) + return {}; + + if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) { + auto lhs = operands[0].cast<AttrElementT>(); + auto rhs = operands[1].cast<AttrElementT>(); + + return AttrElementT::get(lhs.getType(), + calculate(lhs.getValue(), rhs.getValue())); + } else if (operands[0].isa<SplatElementsAttr>() && + operands[1].isa<SplatElementsAttr>()) { + // Both operands are splats so we can avoid expanding the values out and + // just fold based on the splat value. + auto lhs = operands[0].cast<SplatElementsAttr>(); + auto rhs = operands[1].cast<SplatElementsAttr>(); + + auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(), + rhs.getSplatValue<ElementValueT>()); + return DenseElementsAttr::get(lhs.getType(), elementResult); + } else if (operands[0].isa<ElementsAttr>() && + operands[1].isa<ElementsAttr>()) { + // Operands are ElementsAttr-derived; perform an element-wise fold by + // expanding the values.
+ auto lhs = operands[0].cast<ElementsAttr>(); + auto rhs = operands[1].cast<ElementsAttr>(); + + auto lhsIt = lhs.getValues<ElementValueT>().begin(); + auto rhsIt = rhs.getValues<ElementValueT>().begin(); + SmallVector<ElementValueT, 4> elementResults; + elementResults.reserve(lhs.getNumElements()); + for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) + elementResults.push_back(calculate(*lhsIt, *rhsIt)); + return DenseElementsAttr::get(lhs.getType(), elementResults); + } + return {}; +} +} // namespace mlir + +#endif // MLIR_DIALECT_COMMONFOLDERS_H diff --git a/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt b/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..484230778b3d78c85050ad3e77184e1ca23df69f --- /dev/null +++ b/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(FxpMathOps FxpMathOps) diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h new file mode 100644 index 0000000000000000000000000000000000000000..8c0e7aa1aadce647cf5df07ddb72cfd19217b9f6 --- /dev/null +++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h @@ -0,0 +1,31 @@ +//===- FxpMathOps.h - Fixed point ops ---------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_ +#define MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace fxpmath { + +/// Defines the 'FxpMathOps' dialect.
+class FxpMathOpsDialect : public Dialect { +public: + FxpMathOpsDialect(MLIRContext *context); +}; + +#define GET_OP_CLASSES +#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc" + +} // namespace fxpmath +} // namespace mlir + +#endif // MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_ diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td new file mode 100644 index 0000000000000000000000000000000000000000..d527b759a10c19c64614a6995f0ade815988ea41 --- /dev/null +++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td @@ -0,0 +1,277 @@ +//===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the operation definition file for fixed point ops (and real +// equivalents). +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_FXPMATHOPS_FXPMATH_OPS_ +#define DIALECT_FXPMATHOPS_FXPMATH_OPS_ + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/QuantOps/QuantPredicates.td" + +def fxpmath_Dialect : Dialect { + let name = "fxpmath"; +} + +//===----------------------------------------------------------------------===// +// Attributes +//===----------------------------------------------------------------------===// + +// Real value for an (inclusive) min/max clamp limit. +def fxpmath_ClampValueAttr : OptionalAttr; + +// Element-wise activation function to apply. +// Note that RELU activations are not here: they are expressed as clamps. 
+def fxpmath_EwUnaryFnAttr : + StringBasedAttr, "element-wise unary function"> { + let returnType = [{ StringRef }]; + let defaultValue = "IDENTITY"; +} + +class fxpmath_ConstEwUnaryFn : ConstantAttr; +def fxpmath_EwUnaryFn_Abs : fxpmath_ConstEwUnaryFn<"ABS">; +def fxpmath_EwUnaryFn_Exp : fxpmath_ConstEwUnaryFn<"EXP">; +def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">; +def fxpmath_EwUnaryFn_Log : fxpmath_ConstEwUnaryFn<"LOG">; +def fxpmath_EwUnaryFn_Neg : fxpmath_ConstEwUnaryFn<"NEG">; +def fxpmath_EwUnaryFn_Rsqrt : fxpmath_ConstEwUnaryFn<"RSQRT">; +def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">; +def fxpmath_EwUnaryFn_Sign : fxpmath_ConstEwUnaryFn<"SIGN">; +def fxpmath_EwUnaryFn_Sin : fxpmath_ConstEwUnaryFn<"SIN">; +def fxpmath_EwUnaryFn_Sqrt : fxpmath_ConstEwUnaryFn<"SQRT">; +def fxpmath_EwUnaryFn_Square : fxpmath_ConstEwUnaryFn<"SQUARE">; +def fxpmath_EwUnaryFn_Tanh : fxpmath_ConstEwUnaryFn<"TANH">; + +//===----------------------------------------------------------------------===// +// Comparison functions (compares relative to zero on a subtraction result). 
+//===----------------------------------------------------------------------===// + +def fxpmath_CompareZ : StrEnumAttrCase<"CMPZ">; +def fxpmath_CompareNZ : StrEnumAttrCase<"CMPNZ">; +def fxpmath_CompareLZ : StrEnumAttrCase<"CMPLZ">; +def fxpmath_CompareLZE : StrEnumAttrCase<"CMPLZE">; +def fxpmath_CompareGZ : StrEnumAttrCase<"CMPGZ">; +def fxpmath_CompareGZE : StrEnumAttrCase<"CMPGZE">; + +def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn", + "Type of subtraction-result comparison to perform.", + [ + fxpmath_CompareZ, + fxpmath_CompareNZ, + fxpmath_CompareLZ, + fxpmath_CompareLZE, + fxpmath_CompareGZ, + fxpmath_CompareGZE + ]>; + +//===----------------------------------------------------------------------===// +// Base classes +//===----------------------------------------------------------------------===// + +class fxpmath_Op traits> : + Op; + +//===----------------------------------------------------------------------===// +// Fixed-point (fxp) arithmetic ops used by kernels. +// Some of these are temporary pending inclusion into a more core dialect. +//===----------------------------------------------------------------------===// + +def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameOperandsAndResultType]> { + let summary = + "Clamps a signed-integer like argument to a min/max range."; + let description = [{ + Element-wise equivalent to: + r = std::min(clamp_max, std::max(e, clamp_min)) + }]; + let arguments = (ins IntegerLike:$operand, + APIntAttr:$clamp_min, + APIntAttr:$clamp_max); + let results = (outs IntegerLike); +} + +def fxpmath_ConvertISOp : + fxpmath_Op<"convertis", + [NoSideEffect, SameOperandsAndResultShape]> { + let summary = + "Does an element-wise conversion from a signed integer to signed integer"; + let description = [{ + Similar to an element-wise static_cast in C++, from a one signed integer + element type to another. 
+ }]; + let arguments = (ins IntegerLike:$operand); + let results = (outs IntegerLike); +} + +def fxpmath_ConvertISToFOp : + fxpmath_Op<"convertistof", + [NoSideEffect, SameOperandsAndResultShape]> { + let summary = + "Does an element-wise conversion from a signed integer to a float"; + let description = [{ + Similar to an element-wise static_cast in C++, from a signed integer + element type to a floating point element type, rounding to the nearest + floating point value. + }]; + let arguments = (ins IntegerLike:$operand); + let results = (outs FloatLike); +} + + +def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp : + fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis", + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH"; + let description = [{ + Equivalent to the ARMv7 NEON VQRDMULH instruction. + See gemmlowp::SaturatingRoundingDoublingHighMul for a reference + implementation. + }]; + let arguments = (ins IntegerLike:$a, APIntAttr:$b); + let results = (outs IntegerLike); +} + +def fxpmath_RoundingDivideByPotISOp : + fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ + Computes a rounding arithmetic right shift. + }]; + let description = [{ + Computes integer division by a power-of-two, correctly rounded-to-nearest. + Also known as a rounding arithmetic right shift. See + gemmlowp::RoundingDivideByPOT for a reference implementation. + }]; + let arguments = (ins IntegerLike:$operand, APIntAttr:$exponent); + let results = (outs IntegerLike:$res); + let verifier = [{ + auto verifyExponent = exponent().getSExtValue(); + if (verifyExponent < 0 || verifyExponent > 31) { + return emitOpError("exponent must be in range [0..31]"); + } + return success(); + }]; +} + +//===----------------------------------------------------------------------===// +// Real math ops. 
+// +// Math ops on real numbers which may have a representation in quantized +// arithmetic. It is expected that eligible ops are lowered from a source +// dialect to this set of ops prior to the process of converting a computation +// to a quantized form. It is a non-goal of these ops to preserve enough +// information to convert back to the higher level, source dialect. +// +// These ops support either real/floating point or QuantizedTypes as operands +// and results. Since not all transformations are supported (globally or +// sometimes for specific targets), a computation may end up with +// untransformable RealMathOps, in which case they need to be lowered as is +// (using floating point math). +// +// This op set takes advantage of the fact that it is typically trivial to +// combine a math function with a compatible bias addition and real-valued +// clamp (which can be done at a higher accumulation bit depth). +// +// In addition, all element-wise unary functions are collapsed into a single +// fxpmath_RealUnaryEwOp and selected via an enum-like attribute. Especially at +// low bit depths, this makes matching simpler and allows the construction of +// generic LUT-based implementations. It also allows specific lowering rules +// to consolidate runs of chained unary ops and fuse them to preceding math +// ops, potentially allowing them to operate directly on higher precision +// intermediates without resorting to lots of custom kernels for common +// formulas that can suffer from insufficient precision at low bit depths. +// +// Comparison operators are modeled as element-wise unary functions (i.e. +// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit +// quantized value. It is expected that lowering rules can fuse them with +// the preceding sub. 
+//===----------------------------------------------------------------------===// + +class fxpmath_RealMathOp traits = [], dag args> : + fxpmath_Op, + Arguments; + +//===----------------------------------------------------------------------===// +// Element wise binary real math ops. +//===----------------------------------------------------------------------===// + +class fxpmath_RealBinaryOp traits = []> : + fxpmath_RealMathOp, + Results<(outs quant_RealValueType:$res)>; + +class fxpmath_RealBinaryBiasOp traits = []> : + fxpmath_RealMathOp, + Results<(outs quant_RealValueType:$res)>; + +def fxpmath_RealAddEwOp : + fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>; + +def fxpmath_RealSubEwOp : + fxpmath_RealBinaryOp<"real_sub_ew", [NoSideEffect]>; + +def fxpmath_RealMulEwOp : + fxpmath_RealBinaryOp<"real_mul_ew", [NoSideEffect]>; + +def fxpmath_RealDivEwOp : + fxpmath_RealBinaryOp<"real_div_ew", [NoSideEffect]>; + +//===----------------------------------------------------------------------===// +// Element wise unary real math op. +//===----------------------------------------------------------------------===// + +def fxpmath_RealUnaryEwOp : + fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect], + (ins quant_RealValueType:$operand, fxpmath_EwUnaryFnAttr:$fn)>, + Results<(outs quant_RealValueType:$res)>; + +def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>, + Arguments<(ins quant_RealValueType:$operand, fxpmath_CompareFnAttr:$fn)>, + Results<(outs I1Tensor:$res)> { + let description = [{ + Compares a real value to zero, returning an I1 (boolean) tensor with the + result of applying the comparison function. + }]; +} + +//===----------------------------------------------------------------------===// +// Dot op with fused bias addition. 
+//===----------------------------------------------------------------------===// + +def fxpmath_RealMatMulOp : + fxpmath_RealBinaryOp<"real_matmul", [NoSideEffect]> { + let summary = "Matmul"; + let description = [{ + A matrix multiply of [m, k] and [k, n] -> [m, n] where the bias vector is + of shape [n]. Also accepts rank 3 or more input tensors, in which case + the leading dimensions are batch dims. + + Many real systems have specific library calls optimized for this precise + operation, which is why it is handled explicitly versus purely as a + generalized tensor contraction. + }]; +} + +def fxpmath_RealMatMulBiasOp : + fxpmath_RealBinaryBiasOp<"real_matmul_bias", [NoSideEffect]> { + let summary = "Matmul with bias"; + let description = [{ + A specialization of a RealMatMulOp that also accepts an [n] dimension + bias vector. + + In addition, there is often special support for a fused bias and clamp, + which is why they are included. + }]; +} + +#endif // DIALECT_FXPMATHOPS_FXPMATH_OPS_ diff --git a/mlir/include/mlir/Dialect/FxpMathOps/Passes.h b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..aec21c4c18621b3dfc3f0dbdeab001ac6cf6817f --- /dev/null +++ b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines all of the passes owned by the FxpMathOps dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_FXPMATHOPS_PASSES_H +#define MLIR_DIALECT_FXPMATHOPS_PASSES_H + +namespace mlir { +class FuncOp; +template class OpPassBase; + +namespace fxpmath { + +/// Creates a pass that lowers uniform-quantized real math ops to integer +/// arithmetic. This will leave unrecognized real math ops as-is and is +/// typically followed by a pass that lowers any unrecognized ops to a pure +/// floating point form. +OpPassBase *createLowerUniformRealMathPass(); + +/// Creates a pass that lowers uniform-quantized qcast/dcast ops to equivalent +/// operations that perform quantize/dequantize. +OpPassBase *createLowerUniformCastsPass(); + +} // namespace fxpmath +} // namespace mlir + +#endif // MLIR_DIALECT_FXPMATHOPS_PASSES_H diff --git a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fd85b5bcfbfa21463eff0e66ab138f16899d7935 --- /dev/null +++ b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(GPUOps GPUOps) diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h new file mode 100644 index 0000000000000000000000000000000000000000..1776ff7198052fa19087216f2692c6675aac1a41 --- /dev/null +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -0,0 +1,82 @@ +//===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the GPU kernel-related operations and puts them in the +// corresponding dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_GPU_GPUDIALECT_H +#define MLIR_DIALECT_GPU_GPUDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +class FuncOp; + +namespace gpu { + +/// The dialect containing GPU kernel launching operations and related +/// facilities. +class GPUDialect : public Dialect { +public: + /// Create the dialect in the given `context`. + explicit GPUDialect(MLIRContext *context); + /// Get dialect namespace. + static StringRef getDialectNamespace() { return "gpu"; } + + /// Get the name of the attribute used to annotate the modules that contain + /// kernel modules. + static StringRef getContainerModuleAttrName() { + return "gpu.container_module"; + } + + /// Get the canonical string name of the dialect. + static StringRef getDialectName(); + + /// Get the name of the attribute used to annotate external kernel functions. + static StringRef getKernelFuncAttrName() { return "gpu.kernel"; } + + /// Get the name of the attribute used to annotate kernel modules. + static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; } + + /// Returns whether the given function is a kernel function, i.e., has the + /// 'gpu.kernel' attribute. + static bool isKernel(Operation *op); + + /// Returns the numeric value used to identify the workgroup memory address + /// space. + static unsigned getWorkgroupAddressSpace() { return 3; } + + /// Returns the numeric value used to identify the private memory address + /// space. + static unsigned getPrivateAddressSpace() { return 5; } + + LogicalResult verifyOperationAttribute(Operation *op, + NamedAttribute attr) override; +}; + +/// Utility class for the GPU dialect to represent triples of `Value`s +/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation. 
+struct KernelDim3 { + Value x; + Value y; + Value z; +}; + +#define GET_OP_CLASSES +#include "mlir/Dialect/GPU/GPUOps.h.inc" + +} // end namespace gpu +} // end namespace mlir + +#endif // MLIR_DIALECT_GPU_GPUDIALECT_H diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td new file mode 100644 index 0000000000000000000000000000000000000000..b5b93e9b553b58740178f895ad83fe3b4e6f8ce4 --- /dev/null +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -0,0 +1,587 @@ +//===-- GPUOps.td - GPU dialect operation definitions ------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines some operations of the GPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef GPU_OPS +#define GPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +// Type constraint accepting standard integers, indices and wrapped LLVM integer +// types. +def IntLikeOrLLVMInt : TypeConstraint< + Or<[AnyInteger.predicate, Index.predicate, LLVMInt.predicate]>, + "integer, index or LLVM dialect equivalent">; + +//===----------------------------------------------------------------------===// +// GPU Dialect operations. 
+//===----------------------------------------------------------------------===// + +def GPU_Dialect : Dialect { + let name = "gpu"; +} + +class GPU_Op traits = []> : + Op; + +class GPU_IndexOp traits = []> : + GPU_Op, + Arguments<(ins StrAttr:$dimension)>, Results<(outs Index)> { + let verifier = [{ return ::verifyIndexOp(*this); }]; +} + +def GPU_BlockDimOp : GPU_IndexOp<"block_dim">; +def GPU_BlockIdOp : GPU_IndexOp<"block_id">; +def GPU_GridDimOp : GPU_IndexOp<"grid_dim">; +def GPU_ThreadIdOp : GPU_IndexOp<"thread_id">; + +def GPU_GPUFuncOp : GPU_Op<"func", [FunctionLike, IsolatedFromAbove, Symbol]> { + let summary = "Function executable on a GPU"; + + let description = [{ + Defines a function that can be executed on a GPU. This supports memory + attribution and its body has a particular execution model. + + GPU functions are either kernels (as indicated by the `kernel` attribute) or + regular functions. The former can be launched from the host side, while the + latter are device side only. + + The memory attribution defines SSA values that correspond to memory buffers + allocated in the memory hierarchy of the GPU (see below). + + The operation has one attached region that corresponds to the body of the + function. The region arguments consist of the function arguments without + modification, followed by buffers defined in memory annotations. The body of + a GPU function, when launched, is executed by multiple work items. There are + no guarantees on the order in which work items execute, or on the connection + between them. In particular, work items are not necessarily executed in + lock-step. Synchronization ops such as "gpu.barrier" should be used to + coordinate work items. Declarations of GPU functions, i.e. not having the + body region, are not supported. + + Syntax: + + ``` + op ::= `gpu.func` symbol-ref-id `(` argument-list `)` (`->` + function-result-list)? + memory-attribution `kernel`? function-attributes? 
region + + memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? + (`private` `(` ssa-id-and-type-list `)`)? + ``` + + Example: + + ```mlir + gpu.func @foo(%arg0: index) + workgroup(%workgroup: memref<32xf32, 3>) + private(%private: memref<1xf32, 5>) + kernel + attributes {qux: "quux"} { + gpu.return + } + ``` + + The generic form illustrates the concept + + ```mlir + "gpu.func"(%arg: index) {sym_name: "foo", kernel, qux: "quux"} ({ + ^bb0(%arg0: index, %workgroup: memref<32xf32, 3>, + %private: memref<1xf32, 5>): + "gpu.return"() : () -> () + }) : (index) -> () + ``` + + Note the non-default memory spaces used in memref types in memory + attribution. + }]; + + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, StringRef name, " + "FunctionType type, ArrayRef workgroupAttributions = {}, " + "ArrayRef privateAttributions = {}, " + "ArrayRef attrs = {}"> + ]; + + let extraClassDeclaration = [{ + /// Returns `true` if the GPU function defined by this Op is a kernel, i.e. + /// it is intended to be launched from host. + bool isKernel() { + return getAttrOfType(GPUDialect::getKernelFuncAttrName()) != + nullptr; + } + + /// Returns the type of the function this Op defines. + FunctionType getType() { + return getTypeAttr().getValue().cast(); + } + + /// Change the type of this function in place. This is an extremely + /// dangerous operation and it is up to the caller to ensure that this is + /// legal for this function, and to restore invariants: + /// - the entry block args must be updated to match the function params. + /// - the argument/result attributes may need an update: if the new type + /// has less parameters we drop the extra attributes, if there are more + /// parameters they won't have any attributes. + // TODO(b/146349912): consider removing this function thanks to rewrite + // patterns. 
+ void setType(FunctionType newType); + + /// Returns the number of buffers located in the workgroup memory. + unsigned getNumWorkgroupAttributions() { + return getAttrOfType(getNumWorkgroupAttributionsAttrName()) + .getInt(); + } + + /// Returns a list of block arguments that correspond to buffers located in + /// the workgroup memory + ArrayRef getWorkgroupAttributions() { + auto begin = + std::next(getBody().front().args_begin(), getType().getNumInputs()); + auto end = std::next(begin, getNumWorkgroupAttributions()); + return {begin, end}; + } + + /// Returns a list of block arguments that correspond to buffers located in + /// the private memory. + ArrayRef getPrivateAttributions() { + auto begin = + std::next(getBody().front().args_begin(), + getType().getNumInputs() + getNumWorkgroupAttributions()); + return {begin, getBody().front().args_end()}; + } + + /// Returns the name of the attribute containing the number of buffers + /// located in the workgroup memory. + static StringRef getNumWorkgroupAttributionsAttrName() { + return "workgroup_attributions"; + } + + // FunctionLike trait needs access to the functions below. + friend class OpTrait::FunctionLike; + + /// Hooks for the input/output type enumeration in FunctionLike . + unsigned getNumFuncArguments() { return getType().getNumInputs(); } + unsigned getNumFuncResults() { return getType().getNumResults(); } + + /// Returns the keywords used in the custom syntax for this Op. + static StringRef getWorkgroupKeyword() { return "workgroup"; } + static StringRef getPrivateKeyword() { return "private"; } + static StringRef getKernelKeyword() { return "kernel"; } + + /// Hook for FunctionLike verifier. + LogicalResult verifyType(); + + /// Verifies the body of the function. 
+ LogicalResult verifyBody(); + }]; + + // let verifier = [{ return ::verifFuncOpy(*this); }]; + let printer = [{ printGPUFuncOp(p, *this); }]; + let parser = [{ return parseGPUFuncOp(parser, result); }]; +} + +def GPU_LaunchFuncOp : GPU_Op<"launch_func">, + Arguments<(ins IntLikeOrLLVMInt:$gridSizeX, IntLikeOrLLVMInt:$gridSizeY, + IntLikeOrLLVMInt:$gridSizeZ, IntLikeOrLLVMInt:$blockSizeX, + IntLikeOrLLVMInt:$blockSizeY, IntLikeOrLLVMInt:$blockSizeZ, + Variadic:$operands)>, + Results<(outs)> { + let summary = "Launches a function as a GPU kernel"; + + let description = [{ + Launch a kernel function on the specified grid of thread blocks. + `gpu.launch` operations are lowered to `gpu.launch_func` operations by + outlining the kernel body into a function in a dedicated module, which + reflects the separate compilation process. The kernel function is required + to have the `gpu.kernel` attribute. The module containing the kernel + function is required to have the `gpu.kernel_module` attribute and must be + named. And finally, the module containing the kernel module (which thus + cannot be the top-level module) is required to have the + `gpu.container_module` attribute. The `gpu.launch_func` operation has a + string attribute named `kernel` to specify the name of the kernel function + to launch and an attribute named `kernel_module` to specify the name of the + module containing that kernel function. + + The operation takes at least six operands, with the first three operands + being grid sizes along x,y,z dimensions and the following three being block + sizes along x,y,z dimensions. When a lower-dimensional kernel is required, + unused sizes must be explicitly set to `1`. The remaining operands are + passed as arguments to the kernel function. + + A custom syntax for this operation is currently not available. + + Example: + + ```mlir + module attributes {gpu.container_module} { + + // This module creates a separate compilation unit for the GPU compiler. 
+ module @kernels attributes {gpu.kernel_module} { + func @kernel_1(%arg0 : f32, %arg1 : !llvm<"float*">) + attributes { nvvm.kernel = true } { + + // Operations that produce block/thread IDs and dimensions are + // injected when outlining the `gpu.launch` body to a function called + // by `gpu.launch_func`. + %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index) + %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index) + %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index) + + %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index) + %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index) + %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index) + + %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index) + %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index) + %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index) + + %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index) + %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index) + %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) + + "some_op"(%bx, %tx) : (index, index) -> () + %42 = load %arg1[%bx] : memref + } + } + + "gpu.launch_func"(%cst, %cst, %cst, // Grid sizes. + %cst, %cst, %cst, // Block sizes. + %arg0, %arg1) // Arguments passed to the kernel. + { kernel_module = @kernels, // Module containing the kernel. + kernel = "kernel_1" } // Kernel function. 
+ : (index, index, index, index, index, index, f32, !llvm<"float*">) + -> () + } + ``` + }]; + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, " + "Value gridSizeX, Value gridSizeY, Value gridSizeZ, " + "Value blockSizeX, Value blockSizeY, Value blockSizeZ, " + "ValueRange kernelOperands">, + OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, " + "KernelDim3 gridSize, KernelDim3 blockSize, " + "ValueRange kernelOperands"> + ]; + + let extraClassDeclaration = [{ + /// The kernel function specified by the operation's `kernel` attribute. + StringRef kernel(); + + /// The number of operands passed to the kernel function. + unsigned getNumKernelOperands(); + + /// The name of the kernel module specified by the operation's + /// `kernel_module` attribute. + StringRef getKernelModuleName(); + + /// The i-th operand passed to the kernel function. + Value getKernelOperand(unsigned i); + + /// Get the SSA values passed as operands to specify the grid size. + KernelDim3 getGridSizeOperandValues(); + + /// Get the SSA values passed as operands to specify the block size. + KernelDim3 getBlockSizeOperandValues(); + + /// The number of launch configuration operands, placed at the leading + /// positions of the operand list. + static constexpr unsigned kNumConfigOperands = 6; + + // This needs to quietly verify if attributes with names defined below are + // present since it is run before the verifier of this op. + friend LogicalResult GPUDialect::verifyOperationAttribute(Operation *, + NamedAttribute); + + /// The name of the symbolRef attribute specifying the kernel to launch. + static StringRef getKernelAttrName() { return "kernel"; } + + /// The name of the symbolRef attribute specifying the name of the module + /// containing the kernel to launch. 
+ static StringRef getKernelModuleAttrName() { return "kernel_module"; } + }]; + + let verifier = [{ return ::verify(*this); }]; +} + +def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>, + Arguments<(ins Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ, + Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ, + Variadic:$operands)>, + Results<(outs)> { + let summary = "GPU kernel launch operation"; + + let description = [{ + Launch a kernel on the specified grid of thread blocks. The body of the + kernel is defined by the single region that this operation contains. The + operation takes at least six operands, with first three operands being grid + sizes along x,y,z dimensions, the following three arguments being block + sizes along x,y,z dimension, and the remaining operands are arguments of the + kernel. When a lower-dimensional kernel is required, unused sizes must be + explicitly set to `1`. + + The body region has at least _twelve_ arguments, grouped as follows: + + - three arguments that contain block identifiers along x,y,z dimensions; + - three arguments that contain thread identifiers along x,y,z dimensions; + - operands of the `gpu.launch` operation as is, including six leading + operands for grid and block sizes. + + Operations inside the body region, and any operations in the nested regions, + are _not_ allowed to use values defined outside the _body_ region, as if + this region was a function. If necessary, values must be passed as kernel + arguments into the body region. Nested regions inside the kernel body are + allowed to use values defined in their ancestor regions as long as they + don't cross the kernel body region boundary. + + Syntax: + + ``` + operation ::= `gpu.launch` `block` `(` ssa-id-list `)` `in` ssa-reassignment + `threads` `(` ssa-id-list `)` `in` ssa-reassignment + (`args` ssa-reassignment `:` type-list)? + region attr-dict? 
+ ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` + ``` + + Example: + + ```mlir + gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2) + threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5) + args(%arg0 = %6, %arg1 = %7) : f32, memref { + // Block and thread identifiers, as well as block/grid sizes are + // immediately usable inside body region. + "some_op"(%bx, %tx) : (index, index) -> () + %42 = load %arg1[%bx] : memref + } + + // Generic syntax explains how the pretty syntax maps to the IR structure. + "gpu.launch"(%cst, %cst, %c1, // Grid sizes. + %cst, %c1, %c1, // Block sizes. + %arg0, %arg1) // Actual arguments. + {/*attributes*/} + // All sizes and identifiers have "index" size. + : (index, index, index, index, index, index, f32, memref) + -> () { + // The operation passes block and thread identifiers, followed by grid and + // block sizes, followed by actual arguments to the entry block of the + // region. + ^bb0(%bx : index, %by : index, %bz : index, + %tx : index, %ty : index, %tz : index, + %num_bx : index, %num_by : index, %num_bz : index, + %num_tx : index, %num_ty : index, %num_tz : index, + %arg0 : f32, %arg1 : memref): + "some_op"(%bx, %tx) : (index, index) -> () + %3 = "std.load"(%arg1, %bx) : (memref, index) -> f32 + } + ``` + + Rationale: using operation/block arguments gives analyses a clear way of + understanding that a value has additional semantics (e.g., we will need to + know what value corresponds to threadIdx.x for coalescing). We can recover + these properties by analyzing the operations producing values, but it is + easier just to have that information by construction. 
+ }]; + + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, Value gridSizeX," + "Value gridSizeY, Value gridSizeZ, Value blockSizeX," + "Value blockSizeY, Value blockSizeZ," + "ValueRange operands"> + ]; + + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + /// Get the SSA values corresponding to kernel block identifiers. + KernelDim3 getBlockIds(); + /// Get the SSA values corresponding to kernel thread identifiers. + KernelDim3 getThreadIds(); + /// Get the SSA values corresponding to kernel grid size. + KernelDim3 getGridSize(); + /// Get the SSA values corresponding to kernel block size. + KernelDim3 getBlockSize(); + /// Get the operand values passed as kernel arguments. + operand_range getKernelOperandValues(); + /// Get the operand types passed as kernel arguments. + operand_type_range getKernelOperandTypes(); + + /// Get the SSA values passed as operands to specify the grid size. + KernelDim3 getGridSizeOperandValues(); + /// Get the SSA values passed as operands to specify the block size. + KernelDim3 getBlockSizeOperandValues(); + + /// Get the SSA values of the kernel arguments. + iterator_range getKernelArguments(); + + /// Erase the `index`-th kernel argument. Both the entry block argument and + /// the operand will be dropped. The block argument must not have any uses. + void eraseKernelArgument(unsigned index); + + static StringRef getBlocksKeyword() { return "blocks"; } + static StringRef getThreadsKeyword() { return "threads"; } + static StringRef getArgsKeyword() { return "args"; } + + /// The number of launch configuration operands, placed at the leading + /// positions of the operand list. + static constexpr unsigned kNumConfigOperands = 6; + + /// The number of region attributes containing the launch configuration, + /// placed in the leading positions of the argument list. 
+ static constexpr unsigned kNumConfigRegionAttributes = 12; + }]; + + let parser = [{ return parseLaunchOp(parser, result); }]; + let printer = [{ printLaunchOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; +} + +def GPU_ReturnOp : GPU_Op<"return", [Terminator]>, Arguments<(ins)>, + Results<(outs)> { + let summary = "Terminator for GPU launch regions."; + let description = [{ + A terminator operation for regions that appear in the body of `gpu.launch` + operation. These regions are not expected to return any value so the + terminator takes no operands. + }]; + + let parser = [{ return success(); }]; + let printer = [{ p << getOperationName(); }]; +} + +def GPU_YieldOp : GPU_Op<"yield", [Terminator]>, + Arguments<(ins Variadic:$values)> { + let summary = "GPU yield operation"; + let description = [{ + "gpu.yield" is a special terminator operation for blocks inside regions + in gpu ops. It returns values to the immediately enclosing gpu op. + + Example: + + ```gpu.yield %f0, %f1 : f32, f32 + ``` + }]; +} + +// These mirror the XLA ComparisonDirection enum. +def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">; +def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">; + +def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr", + "built-in reduction operations supported by gpu.allreduce.", + [ + GPU_AllReduceOpAdd, + GPU_AllReduceOpMul, + ]>; + +def GPU_AllReduceOp : GPU_Op<"all_reduce", + [SameOperandsAndResultType, IsolatedFromAbove]>, + Arguments<(ins AnyType:$value, + OptionalAttr:$op)>, + Results<(outs AnyType)> { + let summary = "Reduce values among workgroup."; + let description = [{ + The "all_reduce" op reduces the value of every work item across a local + workgroup. The result is equal for all work items of a workgroup. 
+ + For example, both + ``` + %1 = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32) + %2 = "gpu.all_reduce"(%0) ({ + ^bb(%lhs : f32, %rhs : f32): + %sum = addf %lhs, %rhs : f32 + "gpu.yield"(%sum) : (f32) -> () + }) : (f32) -> (f32) + ``` + compute the sum of each work item's %0 value. The first version specifies + the accumulation as operation, whereas the second version specifies the + accumulation as code region. The accumulation operation must either be + `add` or `mul`. + + Either none or all work items of a workgroup need to execute this op + in convergence. + }]; + let regions = (region AnyRegion:$body); + let verifier = [{ return ::verifyAllReduce(*this); }]; +} + +def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">; + +def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr", + "Indexing modes supported by gpu.shuffle.", + [ + GPU_ShuffleOpXor, + ]>; + +def GPU_ShuffleOp : GPU_Op<"shuffle", [NoSideEffect]>, + Arguments<(ins AnyType:$value, I32:$offset, I32:$width, + GPU_ShuffleModeAttr:$mode)>, + Results<(outs AnyType:$result, I1:$valid)> { + let summary = "Shuffles values within a subgroup."; + let description = [{ + The "shuffle" op moves values to a different invocation within the same + subgroup. + + For example + ``` + %1, %2 = gpu.shuffle %0, %offset, %width xor : f32 + ``` + for lane k returns the value from lane `k ^ offset` and `true` if that lane + is smaller than %width. Otherwise it returns an unspecified value and + `false`. A lane is the index of an invocation relative to its subgroup. + + The width specifies the number of invocations that participate in the + shuffle. The width needs to be the same for all invocations that participate + in the shuffle. Exactly the first `width` invocations of a subgroup need to + execute this op in convergence. 
+ }]; + let verifier = [{ return ::verifyShuffleOp(*this); }]; + let printer = [{ printShuffleOp(p, *this); }]; + let parser = [{ return parseShuffleOp(parser, result); }]; +} + +def GPU_BarrierOp : GPU_Op<"barrier"> { + let summary = "Synchronizes all work items of a workgroup."; + let description = [{ + The "barrier" op synchronizes all work items of a workgroup. It is used + to coordinate communication between the work items of the workgroup. + + ``` + gpu.barrier + ``` + waits until all work items in the workgroup have reached this point + and all memory accesses made by these work items prior to the op are + visible to all work items in the workgroup. Data hazards between work items + accessing the same memory can be avoided by synchronizing work items + in-between these accesses. + + Either none or all work items of a workgroup need to execute this op + in convergence. + }]; + let parser = [{ return success(); }]; + let printer = [{ p << getOperationName(); }]; +} + +#endif // GPU_OPS diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..daf6d28d4526a8cb1d698b06970ab85634b05c8f --- /dev/null +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -0,0 +1,27 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_GPU_PASSES_H_ +#define MLIR_DIALECT_GPU_PASSES_H_ + +#include + +namespace mlir { + +class ModuleOp; +template class OpPassBase; + +std::unique_ptr> createGpuKernelOutliningPass(); + +} // namespace mlir + +#endif // MLIR_DIALECT_GPU_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa68eff91b0c52e7658331e369731b0261445fc3 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,19 @@ +set(LLVM_TARGET_DEFINITIONS LLVMOps.td) +mlir_tablegen(LLVMOps.h.inc -gen-op-decls) +mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) +mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRLLVMOpsIncGen) + +add_mlir_dialect(NVVMOps NVVMOps) +add_mlir_dialect(ROCDLOps ROCDLOps) + +set(LLVM_TARGET_DEFINITIONS LLVMOps.td) +mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMConversionsIncGen) +set(LLVM_TARGET_DEFINITIONS NVVMOps.td) +mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRNVVMConversionsIncGen) +set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) +mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRROCDLConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h new file mode 100644 index 0000000000000000000000000000000000000000..d36619bb9a9515a50f94df323e441b0ddef00a58 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -0,0 +1,199 @@ +//===- LLVMDialect.h - MLIR LLVM IR dialect ---------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the LLVM IR dialect in MLIR, containing LLVM operations and +// LLVM type system. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" + +#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc" + +namespace llvm { +class Type; +class LLVMContext; +} // end namespace llvm + +namespace mlir { +namespace LLVM { +class LLVMDialect; + +namespace detail { +struct LLVMTypeStorage; +struct LLVMDialectImpl; +} // namespace detail + +class LLVMType : public mlir::Type::TypeBase { +public: + enum Kind { + LLVM_TYPE = FIRST_LLVM_TYPE, + }; + + using Base::Base; + + static bool kindof(unsigned kind) { return kind == LLVM_TYPE; } + + LLVMDialect &getDialect(); + llvm::Type *getUnderlyingType() const; + + /// Utilities to identify types. + bool isFloatTy() { return getUnderlyingType()->isFloatTy(); } + bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); } + bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); } + bool isIntegerTy(unsigned bitwidth) { + return getUnderlyingType()->isIntegerTy(bitwidth); + } + + /// Array type utilities. + LLVMType getArrayElementType(); + unsigned getArrayNumElements(); + bool isArrayTy(); + + /// Vector type utilities. + LLVMType getVectorElementType(); + bool isVectorTy(); + + /// Function type utilities. 
+ LLVMType getFunctionParamType(unsigned argIdx); + unsigned getFunctionNumParams(); + LLVMType getFunctionResultType(); + bool isFunctionTy(); + + /// Pointer type utilities. + LLVMType getPointerTo(unsigned addrSpace = 0); + LLVMType getPointerElementTy(); + bool isPointerTy(); + + /// Struct type utilities. + LLVMType getStructElementType(unsigned i); + unsigned getStructNumElements(); + bool isStructTy(); + + /// Utilities used to generate floating point types. + static LLVMType getDoubleTy(LLVMDialect *dialect); + static LLVMType getFloatTy(LLVMDialect *dialect); + static LLVMType getHalfTy(LLVMDialect *dialect); + static LLVMType getFP128Ty(LLVMDialect *dialect); + static LLVMType getX86_FP80Ty(LLVMDialect *dialect); + + /// Utilities used to generate integer types. + static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits); + static LLVMType getInt1Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/1); + } + static LLVMType getInt8Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/8); + } + static LLVMType getInt8PtrTy(LLVMDialect *dialect) { + return getInt8Ty(dialect).getPointerTo(); + } + static LLVMType getInt16Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/16); + } + static LLVMType getInt32Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/32); + } + static LLVMType getInt64Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/64); + } + + /// Utilities used to generate other miscellaneous types. 
+ static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements); + static LLVMType getFunctionTy(LLVMType result, ArrayRef params, + bool isVarArg); + static LLVMType getFunctionTy(LLVMType result, bool isVarArg) { + return getFunctionTy(result, llvm::None, isVarArg); + } + static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef elements, + bool isPacked = false); + static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) { + return getStructTy(dialect, llvm::None, isPacked); + } + template + static typename std::enable_if::value, + LLVMType>::type + getStructTy(LLVMType elt1, Args... elts) { + SmallVector fields({elt1, elts...}); + return getStructTy(&elt1.getDialect(), fields); + } + static LLVMType getVectorTy(LLVMType elementType, unsigned numElements); + static LLVMType getVoidTy(LLVMDialect *dialect); + +private: + friend LLVMDialect; + + /// Get an LLVMType with a pre-existing llvm type. + static LLVMType get(MLIRContext *context, llvm::Type *llvmType); + + /// Get an LLVMType with an llvm type that may cause changes to the underlying + /// llvm context when constructed. + static LLVMType getLocked(LLVMDialect *dialect, + function_ref typeBuilder); +}; + +///// Ops ///// +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" + +class LLVMDialect : public Dialect { +public: + explicit LLVMDialect(MLIRContext *context); + ~LLVMDialect(); + static StringRef getDialectNamespace() { return "llvm"; } + + llvm::LLVMContext &getLLVMContext(); + llvm::Module &getLLVMModule(); + + /// Parse a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + /// Print a type registered to this dialect. + void printType(Type type, DialectAsmPrinter &os) const override; + + /// Verify a region argument attribute registered to this dialect. + /// Returns failure if the verification failed, success otherwise. 
+ LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx, + unsigned argIdx, + NamedAttribute argAttr) override; + +private: + friend LLVMType; + + std::unique_ptr impl; +}; + +/// Create an LLVM global containing the string "value" at the module containing +/// surrounding the insertion point of builder. Obtain the address of that +/// global and use it to compute the address of the first character in the +/// string (operations inserted at the builder insertion point). +Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, + StringRef value, LLVM::Linkage linkage, + LLVM::LLVMDialect *llvmDialect); + +/// LLVM requires some operations to be inside of a Module operation. This +/// function confirms that the Operation has the desired properties. +bool satisfiesLLVMModule(Operation *op); + +} // end namespace LLVM +} // end namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td new file mode 100644 index 0000000000000000000000000000000000000000..ed935d5b7f7829cd524e9943f0b2946bd952d5af --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -0,0 +1,52 @@ +//===-- LLVMOpBase.td - LLVM IR dialect shared definitions -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains shared definitions for the LLVM IR dialect and its +// subdialects. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_OP_BASE +#define LLVMIR_OP_BASE + +include "mlir/IR/OpBase.td" + +def LLVM_Dialect : Dialect { + let name = "llvm"; + let cppNamespace = "LLVM"; +} + +// LLVM IR type wrapped in MLIR. 
+def LLVM_Type : Type()">, + "LLVM dialect type">; + +// Type constraint accepting only wrapped LLVM integer types. +def LLVMInt : TypeConstraint< + And<[LLVM_Type.predicate, + CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>, + "LLVM dialect integer">; + +// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder +// used to translate to LLVM IR proper. +class LLVM_OpBase traits = []> : + Op { + // A pattern for constructing the LLVM IR Instruction (or other Value) that + // corresponds to this op. This pattern can use `builder` to refer to an + // `llvm::IRBuilder<>` instance, $-names of arguments and results and the + // following special variable names: + // - $_resultType - substituted with the LLVM IR type of the result; + // - $_numOperands - substituted with the number of operands (including + // the variadic ones); + // - $_hasResult - substituted with a check that a variadic-result op does + // have a result (LLVM ops can have 0 or 1 result); + // - $_location - mlir::Location object of the instruction. + // Additionally, `$$` can be used to produce the dollar character. + string llvmBuilder = ""; +} + +#endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td new file mode 100644 index 0000000000000000000000000000000000000000..2e47eb034747d31d2ce888f2d7ba1ae77e2ce548 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -0,0 +1,734 @@ +//===-- LLVMOps.td - LLVM IR dialect op definition file ----*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the LLVM IR operation definition file. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_OPS +#define LLVMIR_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +// Base class for LLVM operations. All operations get an "llvm." prefix in +// their name automatically. LLVM operations have either zero or one result, +// this class is specialized below for both cases and should not be used +// directly. +class LLVM_Op traits = []> : + LLVM_OpBase { +} + +class LLVM_Builder { + string llvmBuilder = builder; +} + +def LLVM_OneResultOpBuilder : OpBuilder< + "Builder *, OperationState &result, Type resultType, " + "ValueRange operands, ArrayRef attributes = {}", + [{ + if (resultType) result.addTypes(resultType); + result.addOperands(operands); + for (auto namedAttr : attributes) { + result.addAttribute(namedAttr.first, namedAttr.second); + } + }]>; + +def LLVM_ZeroResultOpBuilder : OpBuilder< + "Builder *, OperationState &result, ValueRange operands, " + "ArrayRef attributes = {}", + [{ + result.addOperands(operands); + for (auto namedAttr : attributes) { + result.addAttribute(namedAttr.first, namedAttr.second); + } + }]>; + +class LLVM_TwoBuilders { + list builders = [b1, b2]; +} + +// Base class for LLVM operations with one result. +class LLVM_OneResultOp traits = []> : + LLVM_Op, Results<(outs LLVM_Type:$res)> { + let builders = [LLVM_OneResultOpBuilder]; +} + +// Compatibility builder that takes an instance of wrapped llvm::VoidType +// to indicate no result. 
+def LLVM_VoidResultTypeOpBuilder : OpBuilder< + "Builder *builder, OperationState &result, Type resultType, " + "ValueRange operands, ArrayRef attributes = {}", + [{ + auto llvmType = resultType.dyn_cast(); (void)llvmType; + assert(llvmType && "result must be an LLVM type"); + assert(llvmType.getUnderlyingType() && + llvmType.getUnderlyingType()->isVoidTy() && + "for zero-result operands, only 'void' is accepted as result type"); + build(builder, result, operands, attributes); + }]>; + +// Base class for LLVM operations with zero results. +class LLVM_ZeroResultOp traits = []> : + LLVM_Op, Results<(outs)>, + LLVM_TwoBuilders; + +// Base class for LLVM terminator operations. All terminator operations have +// zero results and an optional list of successors. +class LLVM_TerminatorOp traits = []> : + LLVM_Op, + Arguments<(ins Variadic:$args)>, Results<(outs)> { + let builders = [ + OpBuilder< + "Builder *, OperationState &result, " + "ValueRange properOperands, " + "ArrayRef destinations, " + "ArrayRef operands, " + "ArrayRef attributes = {}", + [{ + result.addOperands(properOperands); + for (auto kvp : llvm::zip(destinations, operands)) { + result.addSuccessor(std::get<0>(kvp), std::get<1>(kvp)); + } + for (auto namedAttr : attributes) { + result.addAttribute(namedAttr.first, namedAttr.second); + } + }] + >, + OpBuilder< + "Builder *builder, OperationState &result, " + "ValueRange properOperands, " + "ArrayRef destinations, " + "ArrayRef attributes = {}", + [{ + SmallVector operands(destinations.size(), {}); + build(builder, result, properOperands, + destinations, operands, attributes); + }] + >, + ]; +} + +// Class for arithmetic binary operations. +class LLVM_ArithmeticOp traits = []> : + LLVM_OneResultOp, + Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>, + LLVM_Builder<"$res = builder." 
# builderFunc # "($lhs, $rhs);"> { + let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; +} +class LLVM_UnaryArithmeticOp traits = []> : + LLVM_OneResultOp, + Arguments<(ins LLVM_Type:$operand)>, + LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> { + let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; +} + +// Integer binary operations. +def LLVM_AddOp : LLVM_ArithmeticOp<"add", "CreateAdd", [Commutative]>; +def LLVM_SubOp : LLVM_ArithmeticOp<"sub", "CreateSub">; +def LLVM_MulOp : LLVM_ArithmeticOp<"mul", "CreateMul", [Commutative]>; +def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv", "CreateUDiv">; +def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv", "CreateSDiv">; +def LLVM_URemOp : LLVM_ArithmeticOp<"urem", "CreateURem">; +def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">; +def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">; +def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">; +def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">; +def LLVM_ShlOp : LLVM_ArithmeticOp<"shl", "CreateShl">; +def LLVM_LShrOp : LLVM_ArithmeticOp<"lshr", "CreateLShr">; +def LLVM_AShrOp : LLVM_ArithmeticOp<"ashr", "CreateAShr">; + +// Predicate for integer comparisons. 
+def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; +def ICmpPredicateNE : I64EnumAttrCase<"ne", 1>; +def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>; +def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>; +def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>; +def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>; +def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>; +def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>; +def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>; +def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>; +def ICmpPredicate : I64EnumAttr< + "ICmpPredicate", + "llvm.icmp comparison predicate", + [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE, + ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE, + ICmpPredicateUGT, ICmpPredicateUGE]> { + let cppNamespace = "::mlir::LLVM"; +} + +// Other integer operations. +def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, + Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs, + LLVM_Type:$rhs)> { + let llvmBuilder = [{ + $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); + }]; + let builders = [OpBuilder< + "Builder *b, OperationState &result, ICmpPredicate predicate, Value lhs, " + "Value rhs", [{ + LLVMDialect *dialect = &lhs->getType().cast().getDialect(); + build(b, result, LLVMType::getInt1Ty(dialect), + b->getI64IntegerAttr(static_cast(predicate)), lhs, rhs); + }]>]; + let parser = [{ return parseCmpOp(parser, result); }]; + let printer = [{ printICmpOp(p, *this); }]; +} + +// Predicate for float comparisons +def FCmpPredicateFALSE : I64EnumAttrCase<"_false", 0>; +def FCmpPredicateOEQ : I64EnumAttrCase<"oeq", 1>; +def FCmpPredicateOGT : I64EnumAttrCase<"ogt", 2>; +def FCmpPredicateOGE : I64EnumAttrCase<"oge", 3>; +def FCmpPredicateOLT : I64EnumAttrCase<"olt", 4>; +def FCmpPredicateOLE : I64EnumAttrCase<"ole", 5>; +def FCmpPredicateONE : I64EnumAttrCase<"one", 6>; +def FCmpPredicateORD : I64EnumAttrCase<"ord", 7>; +def FCmpPredicateUEQ : 
I64EnumAttrCase<"ueq", 8>; +def FCmpPredicateUGT : I64EnumAttrCase<"ugt", 9>; +def FCmpPredicateUGE : I64EnumAttrCase<"uge", 10>; +def FCmpPredicateULT : I64EnumAttrCase<"ult", 11>; +def FCmpPredicateULE : I64EnumAttrCase<"ule", 12>; +def FCmpPredicateUNE : I64EnumAttrCase<"une", 13>; +def FCmpPredicateUNO : I64EnumAttrCase<"uno", 14>; +def FCmpPredicateTRUE : I64EnumAttrCase<"_true", 15>; + +def FCmpPredicate : I64EnumAttr< + "FCmpPredicate", + "llvm.fcmp comparison predicate", + [FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE, + FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD, + FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT, + FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE + ]> { + let cppNamespace = "::mlir::LLVM"; +} + +// Other integer operations. +def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, + Arguments<(ins FCmpPredicate:$predicate, LLVM_Type:$lhs, + LLVM_Type:$rhs)> { + let llvmBuilder = [{ + $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); + }]; + let builders = [OpBuilder< + "Builder *b, OperationState &result, FCmpPredicate predicate, Value lhs, " + "Value rhs", [{ + LLVMDialect *dialect = &lhs->getType().cast().getDialect(); + build(b, result, LLVMType::getInt1Ty(dialect), + b->getI64IntegerAttr(static_cast(predicate)), lhs, rhs); + }]>]; + let parser = [{ return parseCmpOp(parser, result); }]; + let printer = [{ printFCmpOp(p, *this); }]; +} + +// Floating point binary operations. +def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd", "CreateFAdd">; +def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub", "CreateFSub">; +def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul", "CreateFMul">; +def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">; +def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">; +def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">; + +// Memory-related operations. 
+def LLVM_AllocaOp : + LLVM_OneResultOp<"alloca">, + Arguments<(ins LLVM_Type:$arraySize, OptionalAttr:$alignment)> { + string llvmBuilder = [{ + auto *alloca = builder.CreateAlloca( + $_resultType->getPointerElementType(), $arraySize); + if ($alignment.hasValue()) { + auto align = $alignment.getValue().getZExtValue(); + if (align != 0) + alloca->setAlignment(llvm::MaybeAlign(align)); + } + $res = alloca; + }]; + let builders = [OpBuilder< + "Builder *b, OperationState &result, Type resultType, Value arraySize, " + "unsigned alignment", + [{ + if (alignment == 0) + return build(b, result, resultType, arraySize, IntegerAttr()); + build(b, result, resultType, arraySize, b->getI64IntegerAttr(alignment)); + }]>]; + let parser = [{ return parseAllocaOp(parser, result); }]; + let printer = [{ printAllocaOp(p, *this); }]; + let verifier = [{ + if (alignment().hasValue()) { + auto align = alignment().getValue().getSExtValue(); + if (align < 0) + return emitOpError("expected positive alignment"); + } + return success(); + }]; +} +def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$base, Variadic:$indices)>, + LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> { + let parser = [{ return parseGEPOp(parser, result); }]; + let printer = [{ printGEPOp(p, *this); }]; +} +def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>, + LLVM_Builder<"$res = builder.CreateLoad($addr);"> { + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value addr", + [{ + auto type = addr->getType().cast().getPointerElementTy(); + build(b, result, type, addr); + }]>]; + let parser = [{ return parseLoadOp(parser, result); }]; + let printer = [{ printLoadOp(p, *this); }]; +} +def LLVM_StoreOp : LLVM_ZeroResultOp<"store">, + Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>, + LLVM_Builder<"builder.CreateStore($value, $addr);"> { + let parser = [{ return parseStoreOp(parser, result); }]; + let printer = [{ 
printStoreOp(p, *this); }]; +} + +// Casts. +class LLVM_CastOp traits = []> : + LLVM_OneResultOp, + Arguments<(ins LLVM_Type:$arg)>, + LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> { + let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; + let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; +} +def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast">; +def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast">; +def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr">; +def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt">; +def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt">; +def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt">; +def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">; +def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP">; +def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP">; +def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI">; +def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI">; +def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">; +def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">; + +// Call-related operations. 
+def LLVM_CallOp : LLVM_Op<"call">, + Arguments<(ins OptionalAttr:$callee, + Variadic)>, + Results<(outs Variadic)>, + LLVM_TwoBuilders { + let verifier = [{ + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); + return success(); + }]; + let parser = [{ return parseCallOp(parser, result); }]; + let printer = [{ printCallOp(p, *this); }]; +} +def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$vector, + LLVM_Type:$position)> { + string llvmBuilder = [{ + $res = builder.CreateExtractElement($vector, $position); + }]; + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value vector, Value position," + "ArrayRef attrs = {}">]; + let parser = [{ return parseExtractElementOp(parser, result); }]; + let printer = [{ printExtractElementOp(p, *this); }]; +} +def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$container, + ArrayAttr:$position)> { + string llvmBuilder = [{ + $res = builder.CreateExtractValue($container, extractPosition($position)); + }]; + let parser = [{ return parseExtractValueOp(parser, result); }]; + let printer = [{ printExtractValueOp(p, *this); }]; +} +def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value, + LLVM_Type:$position)> { + string llvmBuilder = [{ + $res = builder.CreateInsertElement($vector, $value, $position); + }]; + let parser = [{ return parseInsertElementOp(parser, result); }]; + let printer = [{ printInsertElementOp(p, *this); }]; +} +def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$container, LLVM_Type:$value, + ArrayAttr:$position)> { + string llvmBuilder = [{ + $res = builder.CreateInsertValue($container, $value, + extractPosition($position)); + }]; + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value container, Value value, " + "ArrayAttr 
position", + [{ + build(b, result, container->getType(), container, value, position); + }]>]; + let parser = [{ return parseInsertValueOp(parser, result); }]; + let printer = [{ printInsertValueOp(p, *this); }]; +} +def LLVM_ShuffleVectorOp + : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, ArrayAttr:$mask)>, + LLVM_Builder< + "$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> { + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value v1, Value v2, " + "ArrayAttr mask, ArrayRef attrs = {}">]; + let verifier = [{ + auto wrappedVectorType1 = v1()->getType().cast(); + auto wrappedVectorType2 = v2()->getType().cast(); + if (!wrappedVectorType2.getUnderlyingType()->isVectorTy()) + return emitOpError("expected LLVM IR Dialect vector type for operand #2"); + if (wrappedVectorType1.getVectorElementType() != + wrappedVectorType2.getVectorElementType()) + return emitOpError("expected matching LLVM IR Dialect element types"); + return success(); + }]; + let parser = [{ return parseShuffleVectorOp(parser, result); }]; + let printer = [{ printShuffleVectorOp(p, *this); }]; +} + +// Misc operations. +def LLVM_SelectOp + : LLVM_OneResultOp<"select", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue, + LLVM_Type:$falseValue)>, + LLVM_Builder< + "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value condition, Value lhs, " + "Value rhs", [{ + build(b, result, lhs->getType(), condition, lhs, rhs); + }]>]; + let parser = [{ return parseSelectOp(parser, result); }]; + let printer = [{ printSelectOp(p, *this); }]; +} + +// Terminators. 
+def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { + let parser = [{ return parseBrOp(parser, result); }]; + let printer = [{ printBrOp(p, *this); }]; +} +def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { + let verifier = [{ + if (getNumSuccessors() != 2) + return emitOpError("expected exactly two successors"); + return success(); + }]; + let parser = [{ return parseCondBrOp(parser, result); }]; + let printer = [{ printCondBrOp(p, *this); }]; +} +def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> { + string llvmBuilder = [{ + if ($_numOperands != 0) + builder.CreateRet($args[0]); + else + builder.CreateRetVoid(); + }]; + + let verifier = [{ + if (getNumOperands() > 1) + return emitOpError("expects at most 1 operand"); + return success(); + }]; + + let parser = [{ return parseReturnOp(parser, result); }]; + let printer = [{ printReturnOp(p, *this); }]; +} +def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { + string llvmBuilder = [{ builder.CreateUnreachable(); }]; + let parser = [{ return success(); }]; + let printer = [{ p << getOperationName(); }]; +} + +//////////////////////////////////////////////////////////////////////////////// +// Auxiliary operations (do not appear in LLVM IR but necessary for the dialect +// to work correctly). +//////////////////////////////////////////////////////////////////////////////// + +// Linkage attribute is used on functions and globals. The order follows that of +// https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to +// visible names in the IR rather than to enum values names in llvm::GlobalValue +// since the latter is easier to change. 
+def LinkagePrivate : I64EnumAttrCase<"Private", 0>; +def LinkageInternal : I64EnumAttrCase<"Internal", 1>; +def LinkageAvailableExternally : I64EnumAttrCase<"AvailableExternally", 2>; +def LinkageLinkonce : I64EnumAttrCase<"Linkonce", 3>; +def LinkageWeak : I64EnumAttrCase<"Weak", 4>; +def LinkageCommon : I64EnumAttrCase<"Common", 5>; +def LinkageAppending : I64EnumAttrCase<"Appending", 6>; +def LinkageExternWeak : I64EnumAttrCase<"ExternWeak", 7>; +def LinkageLinkonceODR : I64EnumAttrCase<"LinkonceODR", 8>; +def LinkageWeakODR : I64EnumAttrCase<"WeakODR", 9>; +def LinkageExternal : I64EnumAttrCase<"External", 10>; +def Linkage : I64EnumAttr< + "Linkage", + "LLVM linkage types", + [LinkagePrivate, LinkageInternal, LinkageAvailableExternally, + LinkageLinkonce, LinkageWeak, LinkageCommon, LinkageAppending, + LinkageExternWeak, LinkageLinkonceODR, LinkageWeakODR, LinkageExternal]> { + let cppNamespace = "::mlir::LLVM"; +} + + +def LLVM_AddressOfOp + : LLVM_OneResultOp<"mlir.addressof">, + Arguments<(ins FlatSymbolRefAttr:$global_name)> { + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, LLVMType resType, " + "StringRef name, ArrayRef attrs = {}", [{ + result.addAttribute("global_name", builder->getSymbolRefAttr(name)); + result.addAttributes(attrs); + result.addTypes(resType);}]>, + + OpBuilder<"Builder *builder, OperationState &result, GlobalOp global, " + "ArrayRef attrs = {}", [{ + build(builder, result, + global.getType().getPointerTo(global.addr_space().getZExtValue()), + global.sym_name(), attrs);}]> + ]; + + let extraClassDeclaration = [{ + /// Return the llvm.mlir.global operation that defined the value referenced + /// here. 
+ GlobalOp getGlobal(); + }]; + + let printer = "printAddressOfOp(p, *this);"; + let parser = "return parseAddressOfOp(parser, result);"; + let verifier = "return ::verify(*this);"; +} + +def LLVM_GlobalOp + : LLVM_ZeroResultOp<"mlir.global", + [IsolatedFromAbove, + SingleBlockImplicitTerminator<"ReturnOp">, Symbol]>, + Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name, + Linkage:$linkage, + OptionalAttr:$value, + DefaultValuedAttr:$addr_space)> { + let summary = "LLVM dialect global."; + let description = [{ + Can contain an optional initializer region or attribute for simple + initializers. + + Examples: + // Initialized using an attribute. + llvm.mlir.global @a("abc") : !llvm<"[3 x i8]"> + // Initialized using a region. + llvm.mlir.global constant @b() : !llvm<"i32*"> { + %0 = llvm.constant(0 : i32) : !llvm.i32 + %1 = llvm.inttoptr %0 : !llvm.i32 to !llvm<"i32*"> + llvm.return %1 : !llvm<"i32*"> + } + }]; + let regions = (region AnyRegion:$initializer); + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, LLVMType type, " + "bool isConstant, Linkage linkage, StringRef name, " + "Attribute value, unsigned addrSpace = 0, " + "ArrayRef attrs = {}"> + ]; + + let extraClassDeclaration = [{ + /// Return the LLVM type of the global. + LLVMType getType() { + return type().cast(); + } + /// Return the initializer attribute if it exists, or a null attribute. + Attribute getValueOrNull() { + return value().getValueOr(Attribute()); + } + /// Return the initializer region. This may be empty, but if it is not it + /// terminates in an `llvm.return` op with the initializer value. + Region &getInitializerRegion() { + return getOperation()->getRegion(0); + } + /// Return the initializer block. If the initializer region is empty this + /// is nullptr. If it is not nullptr, it terminates with an `llvm.return` + /// op with the initializer value. + Block *getInitializerBlock() { + return getInitializerRegion().empty() ? 
+ nullptr : &getInitializerRegion().front(); + } + }]; + + let printer = "printGlobalOp(p, *this);"; + let parser = "return parseGlobalOp(parser, result);"; + let verifier = "return ::verify(*this);"; +} + +def LLVM_LLVMFuncOp + : LLVM_ZeroResultOp<"func", [IsolatedFromAbove, FunctionLike, Symbol]>, + Arguments<(ins DefaultValuedAttr:$linkage)> { + let summary = "LLVM dialect function, has wrapped LLVM IR function type"; + + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, StringRef name, " + "LLVMType type, LLVM::Linkage linkage = LLVM::Linkage::External, " + "ArrayRef attrs = {}, " + "ArrayRef argAttrs = {}"> + ]; + + let extraClassDeclaration = [{ + // Add an entry block to an empty function, and set up the block arguments + // to match the signature of the function. + Block *addEntryBlock(); + + LLVMType getType() { + return getAttrOfType(getTypeAttrName()) + .getValue().cast(); + } + bool isVarArg() { + return getType().getUnderlyingType()->isFunctionVarArg(); + } + + // Hook for OpTrait::FunctionLike, returns the number of function arguments. + // Depends on the type attribute being correct as checked by verifyType. + unsigned getNumFuncArguments(); + + // Hook for OpTrait::FunctionLike, returns the number of function results. + // Depends on the type attribute being correct as checked by verifyType. + unsigned getNumFuncResults(); + + // Hook for OpTrait::FunctionLike, called after verifying that the 'type' + // attribute is present. This can check for preconditions of the + // getNumArguments hook not failing. 
+ LogicalResult verifyType(); + }]; + + let verifier = [{ return ::verify(*this); }]; + let printer = [{ printLLVMFuncOp(p, *this); }]; + let parser = [{ return parseLLVMFuncOp(parser, result); }]; +} + +def LLVM_NullOp + : LLVM_OneResultOp<"mlir.null", [NoSideEffect]>, + LLVM_Builder<"$res = llvm::ConstantPointerNull::get(" + " cast($_resultType));"> { + let parser = [{ return parseNullOp(parser, result); }]; + let printer = [{ printNullOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; +} + +def LLVM_UndefOp : LLVM_OneResultOp<"mlir.undef", [NoSideEffect]>, + LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> { + let parser = [{ return parseUndefOp(parser, result); }]; + let printer = [{ printUndefOp(p, *this); }]; +} +def LLVM_ConstantOp + : LLVM_OneResultOp<"mlir.constant", [NoSideEffect]>, + Arguments<(ins AnyAttr:$value)>, + LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);"> +{ + let parser = [{ return parseConstantOp(parser, result); }]; + let printer = [{ printConstantOp(p, *this); }]; +} + +// Operations that correspond to LLVM intrinsics. With MLIR operation set being +// extendable, there is no reason to introduce a hard boundary between "core" +// operations and intrinsics. However, we systematically prefix them with +// "intr." to avoid potential name clashes. + +class LLVM_UnaryIntrinsicOp traits = []> : + LLVM_OneResultOp<"intr." # func, + !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>, + Arguments<(ins LLVM_Type:$in)>, + LLVM_Builder<"$res = builder.CreateCall(llvm::Intrinsic::getDeclaration(" + "builder.GetInsertBlock()->getModule(), llvm::Intrinsic::" # func # "," + "{$in->getType()}), {$in});"> { +} + +class LLVM_BinaryIntrinsicOp traits = []> : + LLVM_OneResultOp<"intr." 
# func, + !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>, + Arguments<(ins LLVM_Type:$a, LLVM_Type:$b)>, + LLVM_Builder<"$res = builder.CreateCall(llvm::Intrinsic::getDeclaration(" + "builder.GetInsertBlock()->getModule(), llvm::Intrinsic::" # func # "," + "{$a->getType(), $b->getType()}), {$a, $b});"> { +} + +class LLVM_TernaryIntrinsicOp traits = []> : + LLVM_OneResultOp<"intr." # func, + !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>, + Arguments<(ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c)>, + LLVM_Builder<"$res = builder.CreateCall(llvm::Intrinsic::getDeclaration(" + "builder.GetInsertBlock()->getModule(), llvm::Intrinsic::" # func # "," + "{$a->getType(), $b->getType(), $c->getType()}), {$a, $b, $c});"> { +} + +def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">; +def LLVM_FAbsOp : LLVM_UnaryIntrinsicOp<"fabs">; +def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">; +def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">; +def LLVM_CopySignOp : LLVM_BinaryIntrinsicOp<"copysign">; +def LLVM_FMulAddOp : LLVM_TernaryIntrinsicOp<"fmuladd">; + +def LLVM_LogOp : LLVM_Op<"intr.log", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$in)>, + Results<(outs LLVM_Type:$res)> { + let llvmBuilder = [{ + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, llvm::Intrinsic::log, {$in->getType()}); + $res = builder.CreateCall(fn, {$in}); + }]; +} + +def LLVM_Log10Op : LLVM_Op<"intr.log10", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$in)>, + Results<(outs LLVM_Type:$res)> { + let llvmBuilder = [{ + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, llvm::Intrinsic::log10, {$in->getType()}); + $res = builder.CreateCall(fn, {$in}); + }]; +} + +def LLVM_Log2Op : LLVM_Op<"intr.log2", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$in)>, + Results<(outs LLVM_Type:$res)> { + let llvmBuilder = [{ + llvm::Module *module = 
builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, llvm::Intrinsic::log2, {$in->getType()}); + $res = builder.CreateCall(fn, {$in}); + }]; +} + +def LLVM_Prefetch : LLVM_ZeroResultOp<"intr.prefetch">, + Arguments<(ins LLVM_Type:$addr, LLVM_Type:$rw, + LLVM_Type:$hint, LLVM_Type:$cache)> { + let llvmBuilder = [{ + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, llvm::Intrinsic::prefetch, $addr->getType()); + builder.CreateCall(fn, {$addr, $rw, $hint, $cache}); + }]; +} + +#endif // LLVMIR_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h new file mode 100644 index 0000000000000000000000000000000000000000..afb6d4ab6272834e2f0baab9885c70be971b233b --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -0,0 +1,36 @@ +//===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the NVVM IR dialect in MLIR, containing NVVM operations and +// NVVM specific extensions to the LLVM type system. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +namespace mlir { +namespace NVVM { + +///// Ops ///// +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/NVVMOps.h.inc" + +class NVVMDialect : public Dialect { +public: + explicit NVVMDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "nvvm"; } +}; + +} // namespace NVVM +} // namespace mlir + +#endif /* MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ */ diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td new file mode 100644 index 0000000000000000000000000000000000000000..f35b7798149247ba98abd8968e596d8308755962 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -0,0 +1,137 @@ +//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the NVVM IR operation definition file. 
+// +//===----------------------------------------------------------------------===// + +#ifndef NVVMIR_OPS +#define NVVMIR_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// NVVM dialect definitions +//===----------------------------------------------------------------------===// + +def NVVM_Dialect : Dialect { + let name = "nvvm"; + let cppNamespace = "NVVM"; +} + +//===----------------------------------------------------------------------===// +// NVVM op definitions +//===----------------------------------------------------------------------===// + +class NVVM_Op traits = []> : + LLVM_OpBase { +} + +//===----------------------------------------------------------------------===// +// NVVM special register op definitions +//===----------------------------------------------------------------------===// + +class NVVM_SpecialRegisterOp traits = []> : + NVVM_Op, + Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { + string llvmBuilder = "$res = createIntrinsicCall(builder," + # "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");"; + let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; +} + +//===----------------------------------------------------------------------===// +// Lane index and range +def NVVM_LaneIdOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.laneid">; +def NVVM_WarpSizeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.warpsize">; + +//===----------------------------------------------------------------------===// +// Thread index and range +def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">; +def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">; +def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">; +def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">; +def NVVM_BlockDimYOp : 
NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">; +def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">; + +//===----------------------------------------------------------------------===// +// Block index and range +def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">; +def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">; +def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">; +def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">; +def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; +def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; + +//===----------------------------------------------------------------------===// +// NVVM synchronization op definitions +//===----------------------------------------------------------------------===// + +def NVVM_Barrier0Op : NVVM_Op<"barrier0"> { + string llvmBuilder = [{ + createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0); + }]; + let parser = [{ return success(); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; +} + +def NVVM_ShflBflyOp : + NVVM_Op<"shfl.sync.bfly">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_Type:$dst, + LLVM_Type:$val, + LLVM_Type:$offset, + LLVM_Type:$mask_and_clamp, + OptionalAttr:$return_value_and_is_valid)> { + string llvmBuilder = [{ + auto intId = getShflBflyIntrinsicId( + $_resultType, static_cast($return_value_and_is_valid)); + $res = createIntrinsicCall(builder, + intId, {$dst, $val, $offset, $mask_and_clamp}); + }]; + let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; + let verifier = [{ + if (!getAttrOfType("return_value_and_is_valid")) + return success(); + auto type = getType().cast(); + if (!type.isStructTy() || type.getStructNumElements() != 2 || + !type.getStructElementType(1).isIntegerTy( + /*Bitwidth=*/1)) + return emitError("expected 
return type !llvm<\"{ ?, i1 }\">"); + return success(); + }]; +} + +def NVVM_VoteBallotOp : + NVVM_Op<"vote.ballot.sync">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> { + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, + llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred}); + }]; + let parser = [{ return parseNVVMVoteBallotOp(parser, result); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; +} + +def NVVM_MmaOp : + NVVM_Op<"mma.sync">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins Variadic:$args)> { + string llvmBuilder = [{ + $res = createIntrinsicCall( + builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args); + }]; + let parser = [{ return parseNVVMMmaOp(parser, result); }]; + let printer = [{ printNVVMMmaOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; +} + +#endif // NVVMIR_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h new file mode 100644 index 0000000000000000000000000000000000000000..dab32d30e8f45dbd2e047b961821b5d3290c5b87 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -0,0 +1,45 @@ +//===- ROCDLDialect.h - MLIR ROCDL IR dialect -------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ROCDL dialect in MLIR, containing ROCDL operations +// and ROCDL specific extensions to the LLVM type system. +// +// Unfortunately there does not exists a formal definition of ROCDL IR that be +// pointed to here. 
However the following links contain more information about +// ROCDL (ROCm-Device-Library) +// +// https://github.com/RadeonOpenCompute/ROCm-Device-Libs/blob/master/doc/OCML.md +// https://github.com/RadeonOpenCompute/ROCm-Device-Libs/blob/master/doc/OCKL.md +// https://llvm.org/docs/AMDGPUUsage.html +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace ROCDL { + +///// Ops ///// +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc" + +class ROCDLDialect : public Dialect { +public: + explicit ROCDLDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "rocdl"; } +}; + +} // namespace ROCDL +} // namespace mlir + +#endif /* MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ */ diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td new file mode 100644 index 0000000000000000000000000000000000000000..697ff9740a844b684e0d4e98b215aec9f9067ccb --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -0,0 +1,92 @@ +//===-- ROCDLOps.td - ROCDL IR dialect op definition file --*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the ROCDL IR operation definition file. 
+// +//===----------------------------------------------------------------------===// + +#ifndef ROCDLIR_OPS +#define ROCDLIR_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// ROCDL dialect definitions +//===----------------------------------------------------------------------===// + +def ROCDL_Dialect : Dialect { + let name = "rocdl"; + let cppNamespace = "ROCDL"; +} + +//===----------------------------------------------------------------------===// +// ROCDL op definitions +//===----------------------------------------------------------------------===// + +class ROCDL_Op traits = []> : + LLVM_OpBase { +} + +//===----------------------------------------------------------------------===// +// ROCDL special register op definitions +//===----------------------------------------------------------------------===// + +class ROCDL_SpecialRegisterOp traits = []> : + ROCDL_Op, + Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { + string llvmBuilder = "$res = createIntrinsicCall(builder," + # "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");"; + let parser = [{ return parseROCDLOp(parser, result); }]; + let printer = [{ printROCDLOp(p, this->getOperation()); }]; +} + +class ROCDL_DeviceFunctionOp traits = []> : + ROCDL_Op, + Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { + string llvmBuilder = "$res = createDeviceFunctionCall(builder, \"" + # device_function # "\", " # parameter # ");"; + let parser = [{ return parseROCDLOp(parser, result); }]; + let printer = [{ printROCDLOp(p, this->getOperation()); }]; +} + +//===----------------------------------------------------------------------===// +// Thread index and Block index + +def ROCDL_ThreadIdXOp : ROCDL_SpecialRegisterOp<"workitem.id.x">; +def ROCDL_ThreadIdYOp : ROCDL_SpecialRegisterOp<"workitem.id.y">; +def ROCDL_ThreadIdZOp : ROCDL_SpecialRegisterOp<"workitem.id.z">; + +def ROCDL_BlockIdXOp : 
ROCDL_SpecialRegisterOp<"workgroup.id.x">; +def ROCDL_BlockIdYOp : ROCDL_SpecialRegisterOp<"workgroup.id.y">; +def ROCDL_BlockIdZOp : ROCDL_SpecialRegisterOp<"workgroup.id.z">; + +//===----------------------------------------------------------------------===// +// Thread range and Block range + +def ROCDL_BlockDimXOp : ROCDL_DeviceFunctionOp<"workgroup.dim.x", + "__ockl_get_local_size", 0>; + +def ROCDL_BlockDimYOp : ROCDL_DeviceFunctionOp<"workgroup.dim.y", + "__ockl_get_local_size", 1>; + +def ROCDL_BlockDimZOp : ROCDL_DeviceFunctionOp<"workgroup.dim.z", + "__ockl_get_local_size", 2>; + +def ROCDL_GridDimXOp : ROCDL_DeviceFunctionOp<"grid.dim.x", + "__ockl_get_global_size", 0>; + +def ROCDL_GridDimYOp : ROCDL_DeviceFunctionOp<"grid.dim.y", + "__ockl_get_global_size", 1>; + +def ROCDL_GridDimZOp : ROCDL_DeviceFunctionOp<"grid.dim.z", + "__ockl_get_global_size", 2>; + + +#endif // ROCDLIR_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h new file mode 100644 index 0000000000000000000000000000000000000000..dd5034e823ceb9a17ecabc16bf231c9f9885e647 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -0,0 +1,134 @@ +//===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ +#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class FuncOp; + +namespace linalg { + +class LinalgOp; + +/// A very primitive alias analysis which just records for each view, either: +/// 1. 
The base buffer, or +/// 2. The block argument view +/// that it indexes into. +/// This does not perform inter-block or inter-procedural analysis and assumes +/// that different block argument views do not alias. +class Aliases { +public: + /// Returns true if v1 and v2 alias. + bool alias(Value v1, Value v2) { return find(v1) == find(v2); } + +private: + /// Returns the base buffer or block argument into which the view `v` aliases. + /// This lazily records the new aliases discovered while walking back the + /// use-def chain. + Value find(Value v); + + DenseMap aliases; +}; + +/// Data structure for holding a dependence graph that operates on LinalgOp and +/// views as SSA values. +class LinalgDependenceGraph { +public: + struct LinalgOpView { + Operation *op; + Value view; + }; + struct LinalgDependenceGraphElem { + // dependentOpView may be either: + // 1. src in the case of dependencesIntoGraphs. + // 2. dst in the case of dependencesFromDstGraphs. + LinalgOpView dependentOpView; + // View in the op that is used to index in the graph: + // 1. src in the case of dependencesFromDstGraphs. + // 2. dst in the case of dependencesIntoGraphs. + Value indexingView; + }; + using LinalgDependences = SmallVector; + using DependenceGraph = DenseMap; + using dependence_iterator = LinalgDependences::const_iterator; + using dependence_range = iterator_range; + + enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes }; + + // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. + static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); + LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); + + /// Returns the X such that op -> X is a dependence of type dt. + dependence_range getDependencesFrom(Operation *src, DependenceType dt) const; + dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const; + + /// Returns the X such that X -> op is a dependence of type dt. 
+ dependence_range getDependencesInto(Operation *dst, DependenceType dt) const; + dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const; + + /// Returns the operations that are interleaved between `srcLinalgOp` and + /// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence + /// relation with `srcLinalgOp`, on any view. + /// Any such operation prevents reordering. + SmallVector + findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const; + + /// Returns the operations that are interleaved between `srcLinalgOp` and + /// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`. + /// Dependences are restricted to views aliasing `view`. + SmallVector findCoveringReads(LinalgOp srcLinalgOp, + LinalgOp dstLinalgOp, + Value view) const; + + /// Returns the operations that are interleaved between `srcLinalgOp` and + /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`. + /// Dependences are restricted to views aliasing `view`. + SmallVector findCoveringWrites(LinalgOp srcLinalgOp, + LinalgOp dstLinalgOp, + Value view) const; + +private: + // Keep dependences in both directions, this is not just a performance gain + // but it also reduces usage errors. + // Dependence information is stored as a map of: + // (source operation -> LinalgDependenceGraphElem) + DependenceGraph dependencesFromGraphs[DependenceType::NumTypes]; + // Reverse dependence information is stored as a map of: + // (destination operation -> LinalgDependenceGraphElem) + DependenceGraph dependencesIntoGraphs[DependenceType::NumTypes]; + + /// Analyses the aliasing views between `src` and `dst` and inserts the proper + /// dependences in the graph. + void addDependencesBetween(LinalgOp src, LinalgOp dst); + + // Adds an new dependence unit in the proper graph. 
+ // Uses std::pair to keep operations and view together and avoid usage errors + // related to src/dst and producer/consumer terminology in the context of + // dependences. + void addDependenceElem(DependenceType dt, LinalgOpView indexingOpView, + LinalgOpView dependentOpView); + + /// Implementation detail for findCoveringxxx. + SmallVector + findOperationsWithCoveringDependences(LinalgOp srcLinalgOp, + LinalgOp dstLinalgOp, Value view, + ArrayRef types) const; + + Aliases &aliases; + SmallVector linalgOps; + DenseMap linalgOpPositions; +}; +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f57627c321fb0c74b3e4a404e3c36bd435f64a7 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h new file mode 100644 index 0000000000000000000000000000000000000000..97fbede1cc78771eefdc8548e0e7d17935a9107b --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -0,0 +1,229 @@ +//===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides intuitive composable interfaces for building structured MLIR +// snippets in a declarative fashion. 
+// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ +#define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ + +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" + +namespace mlir { +class BlockArgument; + +namespace edsc { +enum class IterType { Parallel, Reduction }; + +inline StringRef toString(IterType t) { + switch (t) { + case IterType::Parallel: + return getParallelIteratorTypeName(); + case IterType::Reduction: + return getReductionIteratorTypeName(); + default: + llvm_unreachable("Unsupport IterType"); + } +} + +/// A StructuredIndexed represents a captured value that can be indexed and +/// passed to the `makeLinalgGenericOp`. It allows writing intuitive index +/// expressions such as: +/// +/// ``` +/// StructuredIndexed A(vA), B(vB), C(vC); +/// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... 
); +/// ``` +struct StructuredIndexed { + StructuredIndexed(Value v) : value(v) {} + StructuredIndexed operator()(ArrayRef indexings) { + return StructuredIndexed(value, indexings); + } + + operator Value() const /* implicit */ { return value; } + ArrayRef getExprs() { return exprs; } + +private: + StructuredIndexed(Value v, ArrayRef indexings) + : value(v), exprs(indexings.begin(), indexings.end()) { + assert(v->getType().isa() && "MemRefType expected"); + } + StructuredIndexed(ValueHandle v, ArrayRef indexings) + : StructuredIndexed(v.getValue(), indexings) {} + + Value value; + SmallVector exprs; +}; + +inline void defaultRegionBuilder(ArrayRef args) {} + +Operation *makeLinalgGenericOp( + ArrayRef iteratorTypes, ArrayRef inputs, + ArrayRef outputs, + function_ref)> regionBuilder = + defaultRegionBuilder, + ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); + +namespace ops { +using edsc::StructuredIndexed; +using edsc::ValueHandle; +using edsc::intrinsics::linalg_yield; + +//===----------------------------------------------------------------------===// +// EDSC builders for linalg generic operations. +//===----------------------------------------------------------------------===// + +/// Build the body of a region to compute a multiply-accumulate, under the +/// current ScopedContext, at the current insert point. +void macRegionBuilder(ArrayRef args); + +/// TODO(ntv): In the future we should tie these implementations to something in +/// Tablegen that generates the proper interfaces and the proper sugared named +/// ops. 
+ +/// Build a linalg.pointwise, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (i0, ..., in) = (par, ..., par) +/// | +/// | O...(some_subset...(i0, ..., in)) = +/// | some_pointwise_func...(I...(some_other_subset...(i0, ..., in))) +/// ``` +/// +/// This is a very generic entry point that can be configured in many ways to +/// build a perfect loop nest of parallel loops with arbitrarily complex +/// innermost loop code and whatever (explicit) broadcast semantics. +/// +/// This can be used with both out-of-place and in-place semantics. +/// The client is responsible for ensuring the region operations are compatible +/// with in-place semantics and parallelism. + +/// Unary pointwise operation (with broadcast) entry point. +using UnaryPointwiseOpBuilder = function_ref; +Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, + StructuredIndexed I, StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = tanh(I)`. The client is responsible for specifying the proper +/// indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O); + +/// Binary pointwise operation (with broadcast) entry point. +using BinaryPointwiseOpBuilder = function_ref; +Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, + StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = I1 + I2`. The client is responsible for specifying the proper +/// indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = max(I!, I2)`. 
The client is responsible for specifying the +/// proper indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +// TODO(ntv): Implement more useful pointwise operations on a per-need basis. + +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (m, n, k) = (par, par, seq) +/// | +/// | C(m, n) += A(m, k) * B(k, n) +/// ``` +Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC); + +template Operation *linalg_matmul(Container values) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_matmul(values[0], values[1], values[2]); +} + +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (batch, f, [h, w, ...], [kh, kw, ...], c) = +/// | (par, par, [par, par, ...], [red, red, ...], red) +/// | +/// | O(batch, [h, w, ...], f) += +/// | I(batch, +/// | [ +/// | stride[0] * h + dilations[0] * kh, +/// | stride[1] * w + dilations[1] * kw, ... +/// ], +/// | c) +/// | * +/// | W([kh, kw, ...], c, f) +/// ``` +/// If `dilations` or `strides` are left empty, the default value of `1` is used +/// along each relevant dimension. +/// +/// For now `...` must be empty (i.e. only 2-D convolutions are supported). +/// +// TODO(ntv) Extend convolution rank with some template magic. 
+Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef strides = {}, + ArrayRef dilations = {}); + +template +Operation *linalg_conv_nhwc(Container values, ArrayRef strides = {}, + ArrayRef dilations = {}) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations); +} + +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (batch, dm, c, [h, w, ...], [kh, kw, ...]) = +/// | (par, par, par, [par, par, ...], [red, red, ...]) +/// | +/// | O(batch, [h, w, ...], c * depth_multiplier) += +/// | I(batch, +/// | [ +/// | stride[0] * h + dilations[0] * kh, +/// | stride[1] * w + dilations[1] * kw, ... +/// ], +/// | c) +/// | * +/// | W([kh, kw, ...], c, depth_multiplier) +/// ``` +/// If `dilations` or `strides` are left empty, the default value of `1` is used +/// along each relevant dimension. +/// +/// For now `...` must be empty (i.e. only 2-D convolutions are supported). +/// +// TODO(ntv) Extend convolution rank with some template magic. 
+Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW, + ValueHandle vO, int depth_multiplier = 1, + ArrayRef strides = {}, + ArrayRef dilations = {}); + +template +Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier, + ArrayRef strides = {}, + ArrayRef dilations = {}) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_dilated_conv_nhwc(values[0], values[1], values[2], + depth_multiplier, strides, dilations); +} + +} // namespace ops +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..b04c11f22bb9f1e919aec58e028ccb86d7cad93a --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h @@ -0,0 +1,26 @@ +//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ +#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace edsc { +namespace intrinsics { + +using linalg_fill = OperationBuilder; +using linalg_yield = OperationBuilder; + +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..269729bc644528e4573c3f0e8338570d55f8bd5c --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_dialect(LinalgOps LinalgDoc) +set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td) +mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls) +mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs) +mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen) + diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td new file mode 100644 index 0000000000000000000000000000000000000000..c1adc8b4d05c908ae7eb351067635a233ef1f81b --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -0,0 +1,111 @@ +//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for base linear algebra support. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_BASE +#define LINALG_BASE + +include "mlir/IR/OpBase.td" + +def Linalg_Dialect : Dialect { + let name = "linalg"; + let description = [{ + The `linalg` dialect groups together a set of types, operations and + transformations that are useful to implement a structured abstraction where + ops can lower to scalar load/store and operations or to more general library + calls. + + The `linalg` dialect manipulates the following types and operations: + + ### Core data types and special ops. + + The following abstractions are used by the `linalg` dialect: + + #### Views + The current implementation uses the strided memref abstraction. In the + future other abstractions than strided memref will be used. + + #### `!linalg.range` + This data type is currently just a triple (`min`,`max`, `step`) that does + not pass function boundaries. + + #### `linalg.yield` + This op is used as a terminator within the appropriate `linalg` regions. + + In the future, richer `view` and `range` representations are expected, in + particular to represent sparse traversals. + + ### Metadata Ops + A set of ops that manipulate metadata but do not move memory. These ops take + `view` operands + extra attributes and return new `view`s. The returned + `view`s generally alias the operand `view`. At the moment the existing ops + are: + + * `std.view`, + * `std.subview`, + * `linalg.range`, + * `linalg.slice`, + * `linalg.transpose`. + + Future ops are added on a per-need basis but should include: + + * `linalg.reshape`, + * `linalg.tile`, + * `linalg.intersection`, + * `linalg.convex_union`, + * `linalg.difference` (would need to work on a list of views). 
+ + ### Payload Ops + A set of payload carrying operations that implement the [structured ops]( + https://docs.google.com/presentation/d/1P-j1GrH6Q5gLBjao0afQ-GfvcAeF-QU4GXXeSy0eJ9I/edit#slide=id.p + ) + abstraction on buffers. `linalg` has `2` generic operations `linalg.generic` + and `linalg.indexed_generic` for expressing custom operations. This is + subject to further evolution as transformations and analyses continue to be + developed. + + Additionally, `linalg` provides some common named operations: + + * `linalg.copy`, + * `linalg.fill`, + * `linalg.dot`, + * `linalg.matmul`, + * `linalg.conv`. + + Future ops are added on a per-need basis but should include: + + * `linalg.pad`. + + In an ideal world, all the named ops would be automatically generated from + a description in terms of only the `2` generic ops. Unfortunately we do not + have such support yet (contributions are most welcome). + + ### Convention for external library interop + The `linalg` dialect adopts a convention that is similar to `BLAS` when + offloading operations to fast library implementations: pass a non-owning + pointer to input and output data with additional metadata. This convention + is also found in libraries such as `MKL`, `OpenBLAS`, `BLIS`, `cuBLAS`, + `cuDNN`, etc.. and more generally at interface points across language + boundaries (e.g. C++ / Python). + + Generally, `linalg` passes non-owning pointers to strided memref data + structures to precompiled library calls linked externally. The name `view` + is used interchangeably in `linalg` to signify strided memref discussed at + length in the [strided memref RFC]( + https://groups.google.com/a/tensorflow.org/g/mlir/c/MaL8m2nXuio/m/a_v07o9yBwAJ). + }]; +} + +// Whether a type is a RangeType. 
+def LinalgIsRangeTypePred : CPred<"$_self.isa()">; +def Range : Type; + +#endif // LINALG_BASE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td new file mode 100644 index 0000000000000000000000000000000000000000..819d02d396d4598f42a513eb335e7484b7253f43 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td @@ -0,0 +1,23 @@ +//===- LinalgDoc.td - Linalg documentation -----------------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This documentation files exists to circumvent limitations on mixing different +// .td files in cases one does not want to have all ops belong to the same +// logical unit. This file should only include other .td files only and be used +// for the purpose of generating documentation. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_DOC +#define LINALG_DOC + +include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgOps.td" +include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" + +#endif // LINALG_DOC diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td new file mode 100644 index 0000000000000000000000000000000000000000..6fdb8a644af7dbdb9f209f7e17491c26daff80f0 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -0,0 +1,616 @@ +//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the operation definition file for linear algebra operations that +// correspond to underlying library calls (e.g. BLAS). +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_LIBRARY_OPS +#define LINALG_LIBRARY_OPS + +include "mlir/Dialect/AffineOps/AffineOpsBase.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" + +// The Linalg `NInputs` trait provides the API for ops that are known +// to have a specified number of inputs, all passed as operands. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NInputs : + NativeOpTrait<"linalg::NInputs<" # !cast(args_in) # ">::Impl"> {} + +// The Linalg `NOutputs` trait provides the API for ops that are known +// to have a specified number of outputs, all passed as operands. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NOutputs : + NativeOpTrait<"linalg::NOutputs<" # !cast(args_out) # ">::Impl"> {} + +def ViewTraits : NativeOpTrait<"linalg::ViewTraits">; + +// The linalg 'LinalgLibraryInterface' provides access to the 'LinalgOp' +// interface. 
+def LinalgLibraryInterface : OpInterface<"LinalgOp"> { + let methods = [ + InterfaceMethod< + "Query the number of inputs from the current operation.", + "unsigned", "getNumInputs" + >, + InterfaceMethod< + "Query the number of outputs from the current operation.", + "unsigned", "getNumOutputs" + >, + InterfaceMethod< + "Query the number of inputs and outputs from the current operation.", + "unsigned", "getNumInputsAndOutputs" + >, + InterfaceMethod< + "Query the input operands from the current operation.", + "Operation::operand_range", "getInputs" + >, + InterfaceMethod< + "Query the output operands from the current operation.", + "Operation::operand_range", "getOutputs" + >, + InterfaceMethod< + "Query the input and output operands from the current operation.", + "Operation::operand_range", "getInputsAndOutputs" + >, + InterfaceMethod< + "Query the iterator types attribute within the current operation.", + "ArrayAttr", "iterator_types" + >, + InterfaceMethod< + "Query the indexing maps attribute within the current operation.", + "ArrayAttr", "indexing_maps" + >, + InterfaceMethod< + "Query the number of parallel loops within the current operation.", + "unsigned", "getNumParallelLoops" + >, + InterfaceMethod< + "Query the number of reduction loops within the current operation.", + "unsigned", "getNumReductionLoops" + >, + InterfaceMethod< + "Query the number of window loops within the current operation.", + "unsigned", "getNumWindowLoops" + >, + InterfaceMethod< + "Query the number of loops within the current operation.", + "unsigned", "getNumLoops">, + InterfaceMethod<"Query the input view at the given index.", + "Value ", "getInput", (ins "unsigned":$i) + >, + InterfaceMethod<"Query the output view at the given index.", + "Value ", "getOutput", (ins "unsigned":$i) + >, + InterfaceMethod<[{ + Query the index of the given input value, or `None` if the value is not + an input. 
+ }], + "Optional", "getIndexOfInput", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the index of the given view value, or `None` if the value is not + an view. + }], + "Optional", "getIndexOfOutput", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the type of the input view at the given index. + }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Query the type of the output view at the given index. + }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>, + + StaticInterfaceMethod<[{ + Create an operation of the current type with the given location, + operands, and attributes. + }], + "Operation *", "create", + (ins "OpBuilder &":$builder, "Location":$loc, + "ValueRange":$operands, + "ArrayRef":$attributes), [{ + return builder.create(loc, ArrayRef{}, operands, + attributes); + }] + >, + + /// Clone an operation with the given location and operands. This is used to + /// abstract away the optional underlying region creation. + InterfaceMethod<[{ + Clone the current operation with the given location and operands. This + is used to abstract away the optional underlying region creation. + }], + "Operation *", "clone", + (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{ + BlockAndValueMapping map; + unsigned numRegions = op.getOperation()->getNumRegions(); + Operation *res = create(b, loc, operands, op.getAttrs()); + assert(res->getNumRegions() == numRegions && "inconsistent # regions"); + for (unsigned ridx = 0; ridx < numRegions; ++ridx) + op.getOperation()->getRegion(ridx).cloneInto( + &res->getRegion(ridx), map); + return res; + }] + > + ]; +} + +// Base Tablegen class for Linalg ops. +// Linalg ops that correspond to library calls operate on linalg::View as their +// first operands. These may be optionally followed by non-view operands +// depending on the specific Linalg op. 
+class LinalgLibraryBase_Op props> + : Op { + let parser = [{ return parseLinalgLibraryOp(parser, result); }]; + let printer = [{ printLinalgLibraryOp(p, *this); }]; +} + +class LinalgLibrary_Op props> + : LinalgLibraryBase_Op { + code libraryCallName = [{ + std::string getLibraryCallName() { + return generateLibraryCallName(getOperation()); + } + }]; +} + +//////////////////////////////////////////////////////////////////////////////// +// Concrete Linalg ops. +//////////////////////////////////////////////////////////////////////////////// +def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> { + let description = [{ + Copies the data in the input view into the output view. + + Usage: + ```mlir + linalg.copy(%arg0, %arg1) : memref, + memref + ``` + + One possible lowering to loop form is: + ```mlir + %0 = linalg.dim %arg0, 0 : index + loop.for %i0 = %c0 to %0 step %c1 { + %1 = linalg.load %arg0[%i0] : memref + linalg.store %1, %arg1[%i0] : memref + } + ``` + + Optionally, can take `input_permutation` and `output_permutation` attributes + to reorder the dimensions of the input and output views. + + Usage: + ```mlir + linalg.copy(%arg0, %arg1) {inputPermutation : (i, j, k) -> (i, k, j), + outputPermutation : (i, j, k) -> (k, j, i)} : + memref, + memref + ``` + + One possible lowering to loop form is: + ```mlir + %0 = linalg.dim %arg0, 0 + %1 = linalg.dim %arg0, 1 + %2 = linalg.dim %arg0, 2 + loop.for %i0 = %c0 to %{{.*}} step %c1 { + loop.for %i1 = %c0 to %{{.*}} step %c1 { + loop.for %i2 = %c0 to %{{.*}} step %c1 { + %3 = linalg.load %arg0[%i0, %i2, %i1] : + memref + linalg.store %3, %arg1[%i2, %i1, %i0] : + memref + ``` + + The views are expected to be compatible for correctness but this is not + enforced at the moment. 
+ }]; + let arguments = (ins + AnyStridedMemRef:$input, + AnyStridedMemRef:$output, + OptionalAttr:$inputPermutation, + OptionalAttr:$outputPermutation); + // TODO(ntv) this should go away once the usage of OptionalAttr triggers + // emission of builders with default arguments left unspecified. + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value input, Value output", [{ + return build( + builder, result, input, output, AffineMapAttr(), AffineMapAttr()); + }]>]; + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + unsigned nPar = input()->getType().cast().getRank(); + MLIRContext *ctx = getContext(); + SmallVector iters( + nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + return ArrayAttr::get(iters, ctx); + } + }]; + let verifier = [{ return ::verify(*this); }]; +} + +def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRef:$output, + AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + unsigned nPar = output()->getType().cast().getRank(); + MLIRContext *ctx = getContext(); + SmallVector iters( + nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + return ArrayAttr::get(iters, ctx); + } + }]; + let verifier = [{ return ::verify(*this); }]; +} + +def DotOp : LinalgLibrary_Op<"dot", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<1>, + AnyStridedMemRefOfRank<1>, + AnyStridedMemRefOfRank<0>); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + MLIRContext *ctx = getContext(); + return ArrayAttr::get( + StringAttr::get(getReductionIteratorTypeName(), ctx), ctx); + } + }]; +} + +def MatvecOp : LinalgLibrary_Op<"matvec", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<2>, + 
AnyStridedMemRefOfRank<1>, + AnyStridedMemRefOfRank<1>); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + MLIRContext *ctx = getContext(); + Attribute iters[2]{ + StringAttr::get(getParallelIteratorTypeName(), ctx), + StringAttr::get(getReductionIteratorTypeName(), ctx)}; + return ArrayAttr::get(iters, ctx); + } + }]; +} + +def MatmulOp : LinalgLibrary_Op<"matmul", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<2>, + AnyStridedMemRefOfRank<2>, + AnyStridedMemRefOfRank<2>); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + MLIRContext *ctx = getContext(); + Attribute iters[3]{ + StringAttr::get(getParallelIteratorTypeName(), ctx), + StringAttr::get(getParallelIteratorTypeName(), ctx), + StringAttr::get(getReductionIteratorTypeName(), ctx)}; + return ArrayAttr::get(iters, ctx); + } + }]; +} + +def ConvOp : LinalgLibrary_Op<"conv", [NInputs<2>, NOutputs<1>]> { + let description = [{ + Generic n-D convolution as described in the TF documentation: + https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution + + ``` + output[b, x[0], ..., x[N-1], k] = + sum_{z[0], ..., z[N-1], q} + filter[z[0], ..., z[N-1], q, k] * + padded_input[b, + x[0] * strides[0] + dilation_rate[0] * z[0], + ..., + x[N-1] * strides[N-1] + dilation_rate[N-1] * z[N-1], + q] + ``` + }]; + + // TODO(ntv) padding. + // Following the TF source of truth above, strides and dilations are integer + // attributes of the same rank as the number of window dimensions. + let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input, + AnyStridedMemRef:$output, + OptionalAttr:$strides, + OptionalAttr:$dilations); + let extraClassDeclaration = libraryCallName # [{ + // TODO(ntv) extend to support more than 1 dimensions and potentially + // grouping too. 
+ unsigned getNumBatchDimensions() { return 1; } + unsigned getNumInputFeatureDimensions() { return 1; } + unsigned getNumOutputFeatureDimensions() { return 1; } + + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + // Outer parallel loops are always the number of output dimensions; i.e. + // [ b, xs, q] in the TF notation above. + unsigned nPar = getOutputViewType(0).getRank(); + unsigned nRed = getNumInputFeatureDimensions(); + // Window loops are a special kind of reduction that is never tiled or + // parallelized across; i.e. [zs] in the TF notation above whose number + // match `xs` (i.e. 1 window loop per "image" dimension). + // This may evolve in the future. + unsigned nWin = + nPar - getNumBatchDimensions() - getNumInputFeatureDimensions(); + MLIRContext *ctx = getContext(); + SmallVector iters( + nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + iters.reserve(nPar + nRed + nWin); + iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx)); + iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx)); + return ArrayAttr::get(iters, ctx); + } + + int64_t getStride(unsigned i) { + assert(i < getNumWindowLoops()); + if (!strides().hasValue()) return 1; + return strides()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getDilation(unsigned i) { + assert(i < getNumWindowLoops()); + if (!dilations().hasValue()) return 1; + return dilations()->getValue()[i] + .cast().getValue().getSExtValue(); + } + }]; + let verifier = [{ return ::verify(*this); }]; +} + +class GenericOpBase : LinalgLibraryBase_Op { + let arguments = (ins Variadic:$views, + I64Attr:$args_in, + I64Attr:$args_out, + AffineMapArrayAttr:$indexing_maps, + ArrayAttr:$iterator_types, + OptionalAttr:$doc, + OptionalAttr:$fun, + OptionalAttr:$library_call); + let regions = (region AnyRegion:$region); + let extraClassDeclaration = [{ + SmallVector linalgTraitAttrNames() { + return SmallVector{ + getArgsInAttrName(), 
getArgsOutAttrName(), getDocAttrName(), + getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(), + getIteratorTypesAttrName() + }; + } + unsigned getNumInputs() { return args_in().getSExtValue(); } + unsigned getNumOutputs() { return args_out().getSExtValue(); } + FuncOp getFunction() { + auto moduleOp = getParentOfType(); + return fun().hasValue() ? + moduleOp.lookupSymbol(fun().getValue()) : FuncOp(); + } + StringRef getLibraryCallName() { + return library_call().hasValue() ? library_call().getValue() : ""; + } + AffineMap getIndexingMap(unsigned i) { + assert(i < getNumInputsAndOutputs()); + return indexing_maps().getValue()[i].cast().getValue(); + } + AffineMap getInputIndexingMap(unsigned i) { + assert(i < getNumInputs()); + return indexing_maps().getValue()[i].cast().getValue(); + } + AffineMap getOutputIndexingMap(unsigned i) { + assert(i < getNumOutputs()); + return indexing_maps().getValue()[i + getNumInputs()] + .cast().getValue(); + } + }]; + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parseGenericOp(parser, result); }]; +} + +def GenericOp : GenericOpBase<"generic"> { + let description = [{ + Generic Linalg op form where the key properties of the computation are + specified as attributes. In pretty form, a linalg.generic op is written as: + + ```mlir + linalg.generic #trait_attribute %A, %B, %C {other-attributes} : + memref, + memref, + memref + ``` + + Where #trait_attributes is an alias of a dictionary attribute containing: + - args_in: an I64Attr representing the number of input (readonly) views + - args_out: an I64Attr representing the number of output (readwrite) views + - doc [optional]: a documentation string + - fun: a FlatSymbolRefAttr that must resolve to an existing function + symbol. 
To support inplace updates in a generic fashion, the signature + of the function must be: + ``` + fun([input views element types], [output views element types]) + -> ([output views element types]) + ``` + - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input + and output view. Such AffineMapAttr specifies the mapping between the + loops and the indexing within each view. + - library_call [optional]: a StringAttr containing the name of an + external library function that the linalg.generic operation maps to. + The external library is assumed to be dynamically linked and no strong + compile-time guarantees are provided. In the absence of such a library + call, linalg.generic will always lower to loops. + - iterator_types: an ArrayAttr specifying the type of the enclosing loops. + Each element of the list represents and iterator of one of the following + types: + parallel, reduction, window + + Example: + Defining a #matmul_trait attribute in MLIR can be done as follows: + ```mlir + func @fma(%a: f32, %b: f32, %c: f32) -> f32 { + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + return %e: f32 + } + #matmul_accesses = [ + (m, n, k) -> (m, k), + (m, n, k) -> (k, n), + (m, n, k) -> (m, n) + ] + #matmul_trait = { + doc = "C(m, n) += A(m, k) * B(k, n)", + fun = @fma, + indexing_maps = #matmul_accesses, + library_call = "linalg_matmul", + n_views = [2, 1], + iterator_types = ["parallel", "parallel", "reduction"] + } + ``` + + And can be reused in multiple places as: + ```mlir + linalg.generic #matmul_trait %A, %B, %C [other-attributes] : + memref, + memref, + memref + ``` + + This may lower to either: + ```mlir + call @linalg_matmul(%A, %B, %C) : + (memref, + memref, + memref) + -> () + ``` + + or IR resembling: + ```mlir + loop.for %m = %c0 to %M step %c1 { + loop.for %n = %c0 to %N step %c1 { + loop.for %k = %c0 to %K step %c1 { + %a = linalg.load %A[%m, %k] : memref + %b = linalg.load %B[%k, %n] : memref + %c = linalg.load %C[%m, %n] : memref + %d = 
call @func_of_elements(%a, %b, %c) + : (f32, f32, f32) -> (f32) + linalg.store %d, %C[%m, %n] : memref + } + } + } + ``` + }]; + let verifier = [{ return ::verify(*this); }]; +} + +def IndexedGenericOp : GenericOpBase<"indexed_generic"> { + let description = [{ + Indexed Generic Linalg op form where the key properties of the computation + are specified as attributes. In pretty form, a linalg.indexed_generic op is + written as: + + ```mlir + linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} : + memref, + memref, + memref + ``` + + Where #trait_attributes is an alias of a dictionary attribute containing: + - args_in: an I64Attr representing the number of input (readonly) views + - args_out: an I64Attr representing the number of output (readwrite) views + - doc [optional]: a documentation string + - fun: a FlatSymbolRefAttr that must resolve to an existing function + symbol. To support inplace updates in a generic fashion, the signature + of the function must be: + ``` + fun([index types of induction variables], [input views element types], + [output views element types]) -> ([output views element types]) + ``` + - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input + and output view. Such AffineMapAttr specifies the mapping between the + loops and the indexing within each view. + - library_call [optional]: a StringAttr containing the name of an + external library function that the linalg.indexed_generic operation + maps to. The external library is assumed to be dynamically linked and + no strong compile-time guarantees are provided. In the absence of such + a library call, linalg.indexed_generic will always lower to loops. 
+ - iterator_types: an ArrayAttr they type of the enclosing loops; Each + element of the list represents and iterator of one of the following + types: + parallel, reduction, window + + Example: + Defining a #matmul_trait attribute in MLIR can be done as follows: + ```mlir + func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) + -> f32 + { + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + return %e: f32 + } + #matmul_accesses = [ + (m, n, k) -> (m, k), + (m, n, k) -> (k, n), + (m, n, k) -> (m, n) + ] + #matmul_trait = { + doc = "C(m, n) += A(m, k) * B(k, n)", + fun = @fma, + indexing_maps = #matmul_accesses, + library_call = "linalg_matmul", + n_views = [2, 1], + iterator_types = ["parallel", "parallel", "reduction"] + } + ``` + + And can be reused in multiple places as: + ```mlir + linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] : + memref, + memref, + memref + ``` + + This may lower to either: + ```mlir + call @linalg_matmul(%A, %B, %C) : + (memref, + memref, + memref) + -> () + ``` + + or IR resembling: + ```mlir + loop.for %m = %c0 to %M step %c1 { + loop.for %n = %c0 to %N step %c1 { + loop.for %k = %c0 to %K step %c1 { + %a = linalg.load %A[%m, %k] : memref + %b = linalg.load %B[%k, %n] : memref + %c = linalg.load %C[%m, %n] : memref + %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c) + : (index, index, index, f32, f32, f32) -> (f32) + linalg.store %d, %C[%m, %n] : memref + } + } + } + ``` + }]; + let verifier = [{ return ::verify(*this); }]; +} + +#endif // LINALG_LIBRARY_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h new file mode 100644 index 0000000000000000000000000000000000000000..3249edb48e020f0aec277e3ad8ee766f972c0661 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -0,0 +1,83 @@ +//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_LINALGOPS_H_ +#define MLIR_DIALECT_LINALG_LINALGOPS_H_ + +#include "mlir/Dialect/Linalg/IR/LinalgTraits.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace linalg { + +/// Returns the name mangled library call name to disambiguate between different +/// overloads at the C level. The name mangling scheme is basic and uses MLIR +/// type names: +/// 1. form a string which is the concatenation of the linalg op name with all +/// the operand type names, separate by underscores; +/// 2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type. +/// Assumes `op` is a LinalgOp. +/// +/// Examples: +/// +/// 1. linalg.fill(%A, %f) : memref, f32 +/// name mangles into `linalg_fill_viewf32_f32_impl` +/// +/// 2. linalg.dot(%A, %B, %C) : +/// memref, +/// memref, memref +/// name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl` +/// +/// 3. linalg.matmul(...) : +/// memref, +/// memref, +/// memref +/// name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl` +std::string generateLibraryCallName(Operation *op); + +/// Returns the list of maps that map loops to operands of a Linalg op. +/// The i-th affine map identifies loop indices to subscripts that are used when +/// accessing the i-th operand. 
+/// For instance, a matmul that can be written in index notation as: +/// `A(i, k) * B(k, j) -> C(i, j)` will have the following, ordered, list of +/// affine maps: +/// +/// ```mlir +/// ( +/// (i, j, k) -> (i, k), +/// (i, j, k) -> (k, j), +/// (i, j, k) -> (i, j) +/// ) +/// ``` +/// +/// Only permutation maps are currently supported. +SmallVector loopToOperandRangesMaps(Operation *op); + +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Linalg/IR/LinalgOps.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" + +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_LINALGOPS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td new file mode 100644 index 0000000000000000000000000000000000000000..0445968ee809dec2bdc1f95d69b84a6d772d330c --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -0,0 +1,181 @@ +//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the operation definition file for linear algebra operations. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_OPS +#define LINALG_OPS + +include "mlir/Dialect/AffineOps/AffineOpsBase.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" + +// Base class for Linalg dialect ops that do not correspond to library calls. 
+class Linalg_Op traits = []> : + Op { + // For every linalg op, there needs to be a: + // * void print(OpAsmPrinter &p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, + // OperationState &result) + // functions. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +def Linalg_RangeOp : + Linalg_Op<"range", [NoSideEffect]>, + Arguments<(ins Index:$min, Index:$max, Index:$step)>, + Results<(outs Range)> { + let summary = "Create a `range` type value, used to create `view`s"; + let description = [{ + The `linalg.range` op creates a `!linalg.range` from 3 values of type + `index` that represent the min, max and step values of the `range`. This + type does not pass function boundaries at the moment. + + Example: + + ```mlir + %3 = linalg.range %0:%1:%2 : !linalg.range + ```` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value min, Value max, " + "Value step", + [{ + auto rangeType = RangeType::get(builder->getContext()); + build(builder, result, rangeType, min, max, step); + }]>]; + + // Fully specified by traits. + let verifier = ?; +} + +def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>, + Arguments<(ins AnyStridedMemRef:$view, Variadic>:$indexings)>, + Results<(outs AnyStridedMemRef)> { + let summary = "Produce a rank-reduced `subview` of a base `view`."; + let description = [{ + The `linalg.slice` op allows defining a subregion of a smaller rank than the + operand `view` within the underlying buffer. + + A `linalg.slice` op takes a view and a variadic number of indexings and + produces a `view` of the same elemental type. An indexing is either: + 1. a `linalg.range`, in which case it does not reduce the rank of the + parent `view` along the corresponding dimension. + 2. 
an `index`, in which case it reduces the rank of the parent view by + one. + + If an indexing extends past the size of the `view`, this is undefined + behavior. Ideally the `linalg.slice` operation would automatically truncate + it to be within bounds but there are tradeoffs involved now that `std.view` + is a standard op. + + Examples: + + 1. rank-preserving `slice`: + + ```mlir + %4 = linalg.slice %0[%1, %2] : memref, + !linalg.range, !linalg.range, memref + ``` + + 2. rank-reducing `slice` (from 2-D to 1-D): + + ```mlir + %4 = linalg.slice %0[%1, %2] : memref, + index, !linalg.range, memref + ``` + + 3. rank-reducing `slice` (from 2-D to 0-D): + + ```mlir + %4 = linalg.slice %0[%1, %2] : memref, + index, index, memref + ``` + }]; + + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value base, " + "ValueRange indexings">]; + + let extraClassDeclaration = [{ + enum { FirstIndexingOperand = 1 }; + unsigned getRank() { return getViewType().getRank(); } + Type getElementType() { return getViewType().getElementType(); } + MemRefType getViewType() { return getType().cast(); } + unsigned getBaseViewRank() { return getBaseViewType().getRank(); } + MemRefType getBaseViewType() { return view()->getType().cast(); } + + // Get the underlying indexing at a given rank. + Value indexing(unsigned rank) { return *(indexings().begin() + rank); } + + // Get the subset of indexings that are of RangeType. 
+ SmallVector getRanges() { + SmallVector res; + for (auto operand : indexings()) + if (!operand->getType().isa()) + res.push_back(operand); + return res; + } + }]; +} + +def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, + Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>, + Results<(outs AnyStridedMemRef)> { + let summary = "transpose operation produces a new strided memref (metadata-only)"; + let description = [{ + The `linalg.transpose` op produces a strided memref whose sizes and strides + are a permutation of the original `view`. This is a pure metadata + transformation. + + Example: + + ```mlir + %1 = linalg.transpose %0 (i, j) -> (j, i) : memref + ``` + }]; + + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value view, " + "AffineMapAttr permutation, ArrayRef attrs = {}">]; + + let verifier = [{ + if (!permutation().isPermutation()) + return emitOpError("expected a permutation map"); + if (permutation().getNumDims() != getViewType().getRank()) + return emitOpError("expected a permutation map of same rank as the view"); + return success(); + }]; + + let extraClassDeclaration = [{ + static StringRef getPermutationAttrName() { return "permutation"; } + MemRefType getViewType() { return view()->getType().cast(); } + }]; +} + +def Linalg_YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>, + Arguments<(ins Variadic:$values)> { + let summary = "Linalg yield operation"; + let description = [{ + `linalg.yield` is a special terminator operation for blocks inside regions + in `linalg` generic ops. It returns values to the immediately enclosing + `linalg` generic op. 
+ + Example: + + ```mlir + linalg.yield %f0, %f1 : f32, f32 + ``` + }]; +} + +#endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td new file mode 100644 index 0000000000000000000000000000000000000000..dd9e09b8eae78abdeac426ad6d0739c283235866 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -0,0 +1,616 @@ +//===- LinalgStructuredOps.td - Linalg dialect library ops -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the operation definition file for structured operations on buffers +// that correspond to underlying library calls (e.g. BLAS). +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_STRUCTURED_OPS +#define LINALG_STRUCTURED_OPS + +include "mlir/Dialect/AffineOps/AffineOpsBase.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" + +// The Linalg `NInputs` trait provides the API for ops that are known +// to have a specified number of inputs, all passed as operands. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NInputs : + NativeOpTrait<"linalg::NInputs<" # !cast(args_in) # ">::Impl"> {} + +// The Linalg `NOutputs` trait provides the API for ops that are known +// to have a specified number of outputs, all passed as operands. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NOutputs : + NativeOpTrait<"linalg::NOutputs<" # !cast(args_out) # ">::Impl"> {} + +def ViewTraits : NativeOpTrait<"linalg::ViewTraits">; + +// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp' +// interface. 
+def LinalgStructuredInterface : OpInterface<"LinalgOp"> { + let methods = [ + InterfaceMethod< + "Query the number of inputs from the current operation.", + "unsigned", "getNumInputs" + >, + InterfaceMethod< + "Query the number of outputs from the current operation.", + "unsigned", "getNumOutputs" + >, + InterfaceMethod< + "Query the number of inputs and outputs from the current operation.", + "unsigned", "getNumInputsAndOutputs" + >, + InterfaceMethod< + "Query the input operands from the current operation.", + "Operation::operand_range", "getInputs" + >, + InterfaceMethod< + "Query the output operands from the current operation.", + "Operation::operand_range", "getOutputs" + >, + InterfaceMethod< + "Query the input and output operands from the current operation.", + "Operation::operand_range", "getInputsAndOutputs" + >, + InterfaceMethod< + "Query the iterator types attribute within the current operation.", + "ArrayAttr", "iterator_types" + >, + InterfaceMethod< + "Query the indexing maps attribute within the current operation.", + "ArrayAttr", "indexing_maps" + >, + InterfaceMethod< + "Query the number of parallel loops within the current operation.", + "unsigned", "getNumParallelLoops" + >, + InterfaceMethod< + "Query the number of reduction loops within the current operation.", + "unsigned", "getNumReductionLoops" + >, + InterfaceMethod< + "Query the number of window loops within the current operation.", + "unsigned", "getNumWindowLoops" + >, + InterfaceMethod< + "Query the number of loops within the current operation.", + "unsigned", "getNumLoops">, + InterfaceMethod<"Query the input view at the given index.", + "Value ", "getInput", (ins "unsigned":$i) + >, + InterfaceMethod<"Query the output view at the given index.", + "Value ", "getOutput", (ins "unsigned":$i) + >, + InterfaceMethod<[{ + Query the index of the given input value, or `None` if the value is not + an input. 
+ }], + "llvm::Optional", "getIndexOfInput", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the index of the given view value, or `None` if the value is not + an view. + }], + "llvm::Optional", "getIndexOfOutput", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the type of the input view at the given index. + }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Query the type of the output view at the given index. + }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>, + + StaticInterfaceMethod<[{ + Create an operation of the current type with the given location, + operands, and attributes. + }], + "Operation *", "create", + (ins "OpBuilder &":$builder, "Location":$loc, + "ValueRange":$operands, + "ArrayRef":$attributes), [{ + return builder.create(loc, ArrayRef{}, operands, + attributes); + }] + >, + + /// Clone an operation with the given location and operands. This is used to + /// abstract away the optional underlying region creation. + InterfaceMethod<[{ + Clone the current operation with the given location and operands. This + is used to abstract away the optional underlying region creation. + }], + "Operation *", "clone", + (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{ + BlockAndValueMapping map; + unsigned numRegions = op.getOperation()->getNumRegions(); + Operation *res = create(b, loc, operands, op.getAttrs()); + assert(res->getNumRegions() == numRegions && "inconsistent # regions"); + for (unsigned ridx = 0; ridx < numRegions; ++ridx) + op.getOperation()->getRegion(ridx).cloneInto( + &res->getRegion(ridx), map); + return res; + }] + > + ]; +} + +// Base Tablegen class for Linalg ops. +// Linalg ops that correspond to library calls operate on linalg::View as their +// first operands. These may be optionally followed by non-view operands +// depending on the specific Linalg op. 
+class LinalgStructuredBase_Op props> + : Op { + let parser = [{ return parseLinalgStructuredOp(parser, result); }]; + let printer = [{ printLinalgStructuredOp(p, *this); }]; +} + +class LinalgStructured_Op props> + : LinalgStructuredBase_Op { + code libraryCallName = [{ + std::string getLibraryCallName() { + return generateLibraryCallName(getOperation()); + } + }]; +} + +//////////////////////////////////////////////////////////////////////////////// +// Concrete Linalg ops. +//////////////////////////////////////////////////////////////////////////////// +def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { + let description = [{ + Copies the data in the input view into the output view. + + Usage: + ```mlir + linalg.copy(%arg0, %arg1) : memref, + memref + ``` + + One possible lowering to loop form is: + ```mlir + %0 = linalg.dim %arg0, 0 : index + loop.for %i0 = %c0 to %0 step %c1 { + %1 = linalg.load %arg0[%i0] : memref + linalg.store %1, %arg1[%i0] : memref + } + ``` + + Optionally, can take `input_permutation` and `output_permutation` attributes + to reorder the dimensions of the input and output views. + + Usage: + ```mlir + linalg.copy(%arg0, %arg1) {inputPermutation : (i, j, k) -> (i, k, j), + outputPermutation : (i, j, k) -> (k, j, i)} : + memref, + memref + ``` + + One possible lowering to loop form is: + ```mlir + %0 = linalg.dim %arg0, 0 + %1 = linalg.dim %arg0, 1 + %2 = linalg.dim %arg0, 2 + loop.for %i0 = %c0 to %{{.*}} step %c1 { + loop.for %i1 = %c0 to %{{.*}} step %c1 { + loop.for %i2 = %c0 to %{{.*}} step %c1 { + %3 = linalg.load %arg0[%i0, %i2, %i1] : + memref + linalg.store %3, %arg1[%i2, %i1, %i0] : + memref + ``` + + The views are expected to be compatible for correctness but this is not + enforced at the moment. 
+ }]; + let arguments = (ins + AnyStridedMemRef:$input, + AnyStridedMemRef:$output, + OptionalAttr:$inputPermutation, + OptionalAttr:$outputPermutation); + // TODO(ntv) this should go away once the usage of OptionalAttr triggers + // emission of builders with default arguments left unspecified. + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value input, Value output", [{ + return build( + builder, result, input, output, AffineMapAttr(), AffineMapAttr()); + }]>]; + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + unsigned nPar = input()->getType().cast().getRank(); + MLIRContext *ctx = getContext(); + SmallVector iters( + nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + return ArrayAttr::get(iters, ctx); + } + }]; + let verifier = [{ return ::verify(*this); }]; +} + +def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRef:$input, + AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + unsigned nPar = input()->getType().cast().getRank(); + MLIRContext *ctx = getContext(); + SmallVector iters( + nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + return ArrayAttr::get(iters, ctx); + } + }]; + let verifier = [{ return ::verify(*this); }]; +} + +def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<1>, + AnyStridedMemRefOfRank<1>, + AnyStridedMemRefOfRank<0>); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + MLIRContext *ctx = getContext(); + return ArrayAttr::get( + StringAttr::get(getReductionIteratorTypeName(), ctx), ctx); + } + }]; +} + +def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<2>, + 
AnyStridedMemRefOfRank<1>, + AnyStridedMemRefOfRank<1>); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + MLIRContext *ctx = getContext(); + Attribute iters[2]{ + StringAttr::get(getParallelIteratorTypeName(), ctx), + StringAttr::get(getReductionIteratorTypeName(), ctx)}; + return ArrayAttr::get(iters, ctx); + } + }]; +} + +def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> { + let arguments = (ins AnyStridedMemRefOfRank<2>, + AnyStridedMemRefOfRank<2>, + AnyStridedMemRefOfRank<2>); + let extraClassDeclaration = libraryCallName # [{ + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + MLIRContext *ctx = getContext(); + Attribute iters[3]{ + StringAttr::get(getParallelIteratorTypeName(), ctx), + StringAttr::get(getParallelIteratorTypeName(), ctx), + StringAttr::get(getReductionIteratorTypeName(), ctx)}; + return ArrayAttr::get(iters, ctx); + } + }]; +} + +def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { + let description = [{ + Generic n-D convolution as described in the TF documentation: + https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution + + ``` + output[b, x[0], ..., x[N-1], k] = + sum_{z[0], ..., z[N-1], q} + filter[z[0], ..., z[N-1], q, k] * + padded_input[b, + x[0] * strides[0] + dilation_rate[0] * z[0], + ..., + x[N-1] * strides[N-1] + dilation_rate[N-1] * z[N-1], + q] + ``` + }]; + + // TODO(ntv) padding. + // Following the TF source of truth above, strides and dilations are integer + // attributes of the same rank as the number of window dimensions. + let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input, + AnyStridedMemRef:$output, + OptionalAttr:$strides, + OptionalAttr:$dilations); + let extraClassDeclaration = libraryCallName # [{ + // TODO(ntv) extend to support more than 1 dimensions and potentially + // grouping too. 
+ unsigned getNumBatchDimensions() { return 1; } + unsigned getNumInputFeatureDimensions() { return 1; } + unsigned getNumOutputFeatureDimensions() { return 1; } + + ArrayAttr indexing_maps(); + + ArrayAttr iterator_types() { + // Outer parallel loops are always the number of output dimensions; i.e. + // [ b, xs, q] in the TF notation above. + unsigned nPar = getOutputViewType(0).getRank(); + unsigned nRed = getNumInputFeatureDimensions(); + // Window loops are a special kind of reduction that is never tiled or + // parallelized across; i.e. [zs] in the TF notation above whose number + // match `xs` (i.e. 1 window loop per "image" dimension). + // This may evolve in the future. + unsigned nWin = + nPar - getNumBatchDimensions() - getNumInputFeatureDimensions(); + MLIRContext *ctx = getContext(); + SmallVector iters( + nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + iters.reserve(nPar + nRed + nWin); + iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx)); + iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx)); + return ArrayAttr::get(iters, ctx); + } + + int64_t getStride(unsigned i) { + assert(i < getNumWindowLoops()); + if (!strides().hasValue()) return 1; + return strides()->getValue()[i] + .cast().getValue().getSExtValue(); + } + + int64_t getDilation(unsigned i) { + assert(i < getNumWindowLoops()); + if (!dilations().hasValue()) return 1; + return dilations()->getValue()[i] + .cast().getValue().getSExtValue(); + } + }]; + let verifier = [{ return ::verify(*this); }]; +} + +class GenericOpBase : LinalgStructuredBase_Op { + let arguments = (ins Variadic:$views, + I64Attr:$args_in, + I64Attr:$args_out, + AffineMapArrayAttr:$indexing_maps, + ArrayAttr:$iterator_types, + OptionalAttr:$doc, + OptionalAttr:$fun, + OptionalAttr:$library_call); + let regions = (region AnyRegion:$region); + let extraClassDeclaration = [{ + SmallVector linalgTraitAttrNames() { + return SmallVector{ + getArgsInAttrName(), 
getArgsOutAttrName(), getDocAttrName(), + getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(), + getIteratorTypesAttrName() + }; + } + unsigned getNumInputs() { return args_in().getSExtValue(); } + unsigned getNumOutputs() { return args_out().getSExtValue(); } + FuncOp getFunction() { + auto moduleOp = getParentOfType(); + return fun().hasValue() ? + moduleOp.lookupSymbol(fun().getValue()) : FuncOp(); + } + StringRef getLibraryCallName() { + return library_call().hasValue() ? library_call().getValue() : ""; + } + AffineMap getIndexingMap(unsigned i) { + assert(i < getNumInputsAndOutputs()); + return indexing_maps().getValue()[i].cast().getValue(); + } + AffineMap getInputIndexingMap(unsigned i) { + assert(i < getNumInputs()); + return indexing_maps().getValue()[i].cast().getValue(); + } + AffineMap getOutputIndexingMap(unsigned i) { + assert(i < getNumOutputs()); + return indexing_maps().getValue()[i + getNumInputs()] + .cast().getValue(); + } + }]; + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parseGenericOp(parser, result); }]; +} + +def GenericOp : GenericOpBase<"generic"> { + let description = [{ + Generic Linalg op form where the key properties of the computation are + specified as attributes. In pretty form, a linalg.generic op is written as: + + ```mlir + linalg.generic #trait_attribute %A, %B, %C {other-attributes} : + memref, + memref, + memref + ``` + + Where #trait_attributes is an alias of a dictionary attribute containing: + - args_in: an I64Attr representing the number of input (readonly) views + - args_out: an I64Attr representing the number of output (readwrite) views + - doc [optional]: a documentation string + - fun: a FlatSymbolRefAttr that must resolve to an existing function + symbol. 
To support inplace updates in a generic fashion, the signature + of the function must be: + ``` + fun([input views element types], [output views element types]) + -> ([output views element types]) + ``` + - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input + and output view. Such AffineMapAttr specifies the mapping between the + loops and the indexing within each view. + - library_call [optional]: a StringAttr containing the name of an + external library function that the linalg.generic operation maps to. + The external library is assumed to be dynamically linked and no strong + compile-time guarantees are provided. In the absence of such a library + call, linalg.generic will always lower to loops. + - iterator_types: an ArrayAttr specifying the type of the enclosing loops. + Each element of the list represents and iterator of one of the following + types: + parallel, reduction, window + + Example: + Defining a #matmul_trait attribute in MLIR can be done as follows: + ```mlir + func @fma(%a: f32, %b: f32, %c: f32) -> f32 { + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + return %e: f32 + } + #matmul_accesses = [ + (m, n, k) -> (m, k), + (m, n, k) -> (k, n), + (m, n, k) -> (m, n) + ] + #matmul_trait = { + doc = "C(m, n) += A(m, k) * B(k, n)", + fun = @fma, + indexing_maps = #matmul_accesses, + library_call = "linalg_matmul", + n_views = [2, 1], + iterator_types = ["parallel", "parallel", "reduction"] + } + ``` + + And can be reused in multiple places as: + ```mlir + linalg.generic #matmul_trait %A, %B, %C [other-attributes] : + memref, + memref, + memref + ``` + + This may lower to either: + ```mlir + call @linalg_matmul(%A, %B, %C) : + (memref, + memref, + memref) + -> () + ``` + + or IR resembling: + ```mlir + loop.for %m = %c0 to %M step %c1 { + loop.for %n = %c0 to %N step %c1 { + loop.for %k = %c0 to %K step %c1 { + %a = linalg.load %A[%m, %k] : memref + %b = linalg.load %B[%k, %n] : memref + %c = linalg.load %C[%m, %n] : memref + %d = 
call @func_of_elements(%a, %b, %c) + : (f32, f32, f32) -> (f32) + linalg.store %d, %C[%m, %n] : memref + } + } + } + ``` + }]; + let verifier = [{ return ::verify(*this); }]; +} + +def IndexedGenericOp : GenericOpBase<"indexed_generic"> { + let description = [{ + Indexed Generic Linalg op form where the key properties of the computation + are specified as attributes. In pretty form, a linalg.indexed_generic op is + written as: + + ```mlir + linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} : + memref, + memref, + memref + ``` + + Where #trait_attributes is an alias of a dictionary attribute containing: + - args_in: an I64Attr representing the number of input (readonly) views + - args_out: an I64Attr representing the number of output (readwrite) views + - doc [optional]: a documentation string + - fun: a FlatSymbolRefAttr that must resolve to an existing function + symbol. To support inplace updates in a generic fashion, the signature + of the function must be: + ``` + fun([index types of induction variables], [input views element types], + [output views element types]) -> ([output views element types]) + ``` + - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input + and output view. Such AffineMapAttr specifies the mapping between the + loops and the indexing within each view. + - library_call [optional]: a StringAttr containing the name of an + external library function that the linalg.indexed_generic operation + maps to. The external library is assumed to be dynamically linked and + no strong compile-time guarantees are provided. In the absence of such + a library call, linalg.indexed_generic will always lower to loops. 
+ - iterator_types: an ArrayAttr they type of the enclosing loops; Each + element of the list represents and iterator of one of the following + types: + parallel, reduction, window + + Example: + Defining a #matmul_trait attribute in MLIR can be done as follows: + ```mlir + func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) + -> f32 + { + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + return %e: f32 + } + #matmul_accesses = [ + (m, n, k) -> (m, k), + (m, n, k) -> (k, n), + (m, n, k) -> (m, n) + ] + #matmul_trait = { + doc = "C(m, n) += A(m, k) * B(k, n)", + fun = @fma, + indexing_maps = #matmul_accesses, + library_call = "linalg_matmul", + n_views = [2, 1], + iterator_types = ["parallel", "parallel", "reduction"] + } + ``` + + And can be reused in multiple places as: + ```mlir + linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] : + memref, + memref, + memref + ``` + + This may lower to either: + ```mlir + call @linalg_matmul(%A, %B, %C) : + (memref, + memref, + memref) + -> () + ``` + + or IR resembling: + ```mlir + loop.for %m = %c0 to %M step %c1 { + loop.for %n = %c0 to %N step %c1 { + loop.for %k = %c0 to %K step %c1 { + %a = linalg.load %A[%m, %k] : memref + %b = linalg.load %B[%k, %n] : memref + %c = linalg.load %C[%m, %n] : memref + %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c) + : (index, index, index, f32, f32, f32) -> (f32) + linalg.store %d, %C[%m, %n] : memref + } + } + } + ``` + }]; + let verifier = [{ return ::verify(*this); }]; +} + +#endif // LINALG_STRUCTURED_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..e0d651806d3207d233e7aec8c21e2e04eaa72fdb --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -0,0 +1,157 @@ +//===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under 
the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_ +#define MLIR_DIALECT_LINALG_LINALGTRAITS_H_ + +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace OpTrait { +namespace linalg { + +/// This class provides the API for ops that are known to have a specified +/// number of inputs, all passed as operands. This is used as a trait like this: +/// +/// class DotOp : public Op::Impl> { +/// +template class NInputs { +public: + template + class Impl : public OpTrait::TraitBase::Impl> { + public: + static unsigned getNumInputs() { return N; } + }; +}; + +/// This class provides the API for ops that are known to have a specified +/// number of inputs, all passed as operands. This is used as a trait like this: +/// +/// class DotOp : public Op::Impl> { +/// +template class NOutputs { +public: + template + class Impl : public OpTrait::TraitBase::Impl> { + public: + static unsigned getNumOutputs() { return N; } + }; +}; + +/// This class provides the API for ops that are known to operate on views. This +/// trait must be used in conjunction with an op definition or a trait that +/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a +/// trait like this: +/// +/// class DotOp : public Op { +/// +template +class ViewTraits : public OpTrait::TraitBase { +private: + /// Return the number of input views. For internal use only. + unsigned nInputs() { + return cast(this->getOperation()).getNumInputs(); + } + /// Return the number of input views. For internal use only. 
+ unsigned nOutputs() { + return cast(this->getOperation()).getNumOutputs(); + } + +public: + /// Return the `i`-th input view. + Value getInput(unsigned i) { + assert(i < nInputs()); + return this->getOperation()->getOperand(i); + } + /// Return the index of `view` in the list of input views if found, llvm::None + /// otherwise. + Optional getIndexOfInput(Value view) { + auto it = llvm::find(getInputs(), view); + if (it != getInputs().end()) + return it - getInputs().begin(); + return llvm::None; + } + /// Return the `i`-th input view type. + MemRefType getInputViewType(unsigned i) { + return getInput(i)->getType().template cast(); + } + /// Return the range over input views. + Operation::operand_range getInputs() { + auto range = this->getOperation()->getOperands(); + return {range.begin(), range.begin() + nInputs()}; + } + /// Return the `i`-th output view. + Value getOutput(unsigned i) { + return this->getOperation()->getOperand(nInputs() + i); + } + /// Return the index of `view` in the list of output views if found, + /// llvm::None otherwise. + Optional getIndexOfOutput(Value view) { + auto it = llvm::find(getOutputs(), view); + if (it != getOutputs().end()) + return it - getOutputs().begin(); + return llvm::None; + } + /// Return the `i`-th output view type. + MemRefType getOutputViewType(unsigned i) { + return getOutput(i)->getType().template cast(); + } + /// Return the range over output views. + Operation::operand_range getOutputs() { + auto range = this->getOperation()->getOperands(); + return {range.begin() + nInputs(), + range.begin() + getNumInputsAndOutputs()}; + } + /// Return the number of input and output views. + unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } + /// Return the `i`-th view type. + MemRefType getViewType(unsigned i) { + return (i < nInputs()) ? getInputViewType(i) + : getOutputViewType(i - nInputs()); + } + /// Return the range over input and output views. 
+ Operation::operand_range getInputsAndOutputs() { + auto range = this->getOperation()->getOperands(); + return {range.begin(), range.begin() + getNumInputsAndOutputs()}; + } + unsigned getNumParallelLoops() { + return getNumIterators( + getParallelIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumReductionLoops() { + return getNumIterators( + getReductionIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumWindowLoops() { + return getNumIterators( + getWindowIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumLoops() { + return getNumIterators( + cast(this->getOperation()).iterator_types()); + } + static LogicalResult verifyTrait(Operation *op) { + auto nViews = cast(op).getNumInputsAndOutputs(); + if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews))) + return failure(); + return success(); + } +}; + +} // namespace linalg +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h new file mode 100644 index 0000000000000000000000000000000000000000..abeda3e05528b6d7ba1106cb6cf7dcb9a07573cf --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -0,0 +1,61 @@ +//===- LinalgTypes.h - Linalg Types ---------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_ +#define MLIR_DIALECT_LINALG_LINALGTYPES_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Types.h" + +namespace mlir { +class MLIRContext; + +namespace linalg { +enum LinalgTypes { + Range = Type::FIRST_LINALG_TYPE, + LAST_USED_LINALG_TYPE = Range, +}; + +class LinalgDialect : public Dialect { +public: + explicit LinalgDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "linalg"; } + + /// Parse a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + /// Print a type registered to this dialect. + void printType(Type type, DialectAsmPrinter &os) const override; +}; + +/// A RangeType represents a minimal range abstraction (min, max, step). +/// It is constructed by calling the linalg.range op with three values index of +/// index type: +/// +/// ```mlir +/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) { +/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range +/// } +/// ``` +class RangeType : public Type::TypeBase { +public: + // Used for generic hooks in TypeBase. + using Base::Base; + /// Construction hook. + static RangeType get(MLIRContext *context) { + /// Custom, uniq'ed construction in the MLIRContext. + return Base::get(context, LinalgTypes::Range); + } + /// Used to implement llvm-style cast. 
+ static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; } +}; + +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_LINALGTYPES_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..86cf6fdd02797aa06d1814816805c9d9cd053d9d --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -0,0 +1,48 @@ +//===- Passes.h - Linalg pass entry points ----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_PASSES_H_ +#define MLIR_DIALECT_LINALG_PASSES_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { +class FuncOp; +class ModuleOp; +template class OpPassBase; + +namespace linalg { +std::unique_ptr> createLinalgFusionPass(); + +std::unique_ptr> +createLinalgTilingPass(ArrayRef tileSizes = {}); + +std::unique_ptr> +createLinalgPromotionPass(bool dynamicBuffers); + +/// Create a pass to convert Linalg operations to loop.for loops and +/// std.load/std.store accesses. +std::unique_ptr> createConvertLinalgToLoopsPass(); + +/// Create a pass to convert Linalg operations to affine.for loops and +/// affine_load/affine_store accesses. +/// Placeholder for now, this is NYI. +std::unique_ptr> createConvertLinalgToAffineLoopsPass(); + +/// Create a pass to convert Linalg operations to the LLVMIR dialect. 
+std::unique_ptr> createConvertLinalgToLLVMPass(); + +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f87938c943effd0b1ea7cda1241c8b6325549dc4 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td) +mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td new file mode 100644 index 0000000000000000000000000000000000000000..8f6762f004896c2605be0dc1e808b2f76b000f25 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -0,0 +1,108 @@ +//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the pattern definition file for declarative Linalg transformation. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_TRANSFORMS +#define LINALG_TRANSFORMS + +include "mlir/Dialect/Linalg/IR/LinalgOps.td" +include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" +include "mlir/Dialect/AffineOps/AffineOps.td" + +def HasNoLinalgTransformMarker : CPred<[{ + !$0.getAttrOfType(LinalgTransforms::kLinalgTransformMarker) +}]>; + +class HasLinalgTransformMarker : CPred<[{ + $0.getAttrOfType( + LinalgTransforms::kLinalgTransformMarker) && + $0.getAttrOfType( + LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>; + +class IsProducedByOpOfType : + CPred<"isProducedByOpOfType<" # str # ">($0, $1)">; + +class AffineMapDomainHasDim : CPred<[{ + $0.getAttrOfType(getIndexingMapsAttrName()).getValue()[0]. + cast().getValue().getNumDims() ==}] # n # [{}]>; + +class HasOperandsOfType: CPred<[{ + llvm::any_of($0.getOperands(), + [](Value v) { + return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); + }) +}]>; + +//===----------------------------------------------------------------------===// +// Linalg fusion patterns. +//===----------------------------------------------------------------------===// +// +// In the future, tile sizes should be derived from op properties + machine +// description but we do not need to wait on this to start having useful +// patterns. +class TileAndFuseLinalgOp< + list sizes, list operandIndices, string value> : NativeCodeCall< + "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" # + StrJoinInt.result # "}, {" # StrJoinInt.result # "}," # + " \"" # value # "\")))" # + " return matchFailure();">; + +//===----------------------------------------------------------------------===// +// Linalg tiling patterns. 
+//===----------------------------------------------------------------------===// +// +// In the future, tile sizes should be derived from op properties + machine +// description but we do not need to wait on this to start having useful +// patterns. +// `permutation` is an optional parameter to specify the ordering of the +// tiled loops. If provided, it must be a list of integers with the same number +// of elements as `sizes`. +class TileLinalgOp sizes, string value, list permutation=[]> : + NativeCodeCall< + "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" # + StrJoinInt.result # "}, \"" # value # "\", {" # + StrJoinInt.result # "})))" # + " return matchFailure();">; + +//===----------------------------------------------------------------------===// +// Linalg to loop patterns. +//===----------------------------------------------------------------------===// +class LinalgOpToLoops : NativeCodeCall< + "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " # + " return matchFailure();">; + +class LinalgOpToAffineLoops : NativeCodeCall< + "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " # + " return matchFailure();">; + +//===----------------------------------------------------------------------===// +// Linalg to vector contraction patterns. +//===----------------------------------------------------------------------===// +class LinalgOpToVectorContraction : NativeCodeCall< + "if (failed(vectorizeGenericOp($_builder, $0))) " # + " return matchFailure();">; + +//===----------------------------------------------------------------------===// +// Linalg generic permutation patterns. 
+//===----------------------------------------------------------------------===// +class PermuteGenericLinalgOp permutation, string value> : + NativeCodeCall< + "if (failed(permuteGenericLinalgOp($_builder, $0, {" # + StrJoinInt.result # "}, \"" # value # "\"))) " # + " return matchFailure();">; + +//===----------------------------------------------------------------------===// +// Linalg promote subview operands. +//===----------------------------------------------------------------------===// +class LinalgOpPromoteSubviews : NativeCodeCall< + "if (failed(linalgOpPromoteSubviews($_builder, $0))) " # + " return matchFailure();">; +#endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h new file mode 100644 index 0000000000000000000000000000000000000000..757ee3ad1a7bdeaae72d896320fba07e15afe688 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -0,0 +1,96 @@ +//===- LinalgTransforms.h - Linalg transformations as patterns --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ +#define DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace linalg { + +// Marker used as attribute name in generated Linalg rewriting transformations. 
+struct LinalgTransforms { + static const StringLiteral kLinalgTransformMarker; +}; + +namespace detail { +// Implementation detail of isProducedByOpOfType avoids the need for explicit +// template instantiations. +bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value consumedView, + function_ref isaOpType); +} // namespace detail + +// Returns true if the `consumedView` value use in `consumerOp` is produced by +// an op of type `OpTy`. This is used to implement use-def type information on +// buffers. +template +bool isProducedByOpOfType(Operation *consumerOp, Value consumedView) { + return detail::isProducedByOpOfTypeImpl( + consumerOp, consumedView, [](Operation *op) { return isa(op); }); +} + +//////////////////////////////////////////////////////////////////////////////// +// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite +// patterns. As such, they must not call into `rewriter.erase/replace` APIs and +// it is the responsibility of the enclosing PatternRewriter to erase on +// success. +//////////////////////////////////////////////////////////////////////////////// + +/// Tiles `op` by `sizes` permuting the looops according to `permutation` +/// and sets the attribute `kLinalgTransformMarker` to `linalgMarker`. +/// The permutation is expressed as a list of integers that specify +/// the new ordering of the loop nest. The length of `permutation` +/// must be equal to the length of `tileSizes`. +/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with +/// `permutation = [1,2,0]`. All values in `permutation` must be +/// integers, in the range 0..`tileSizes.size()` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). An empty list +/// states for the identity permutation. 
+LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, + ArrayRef sizes, + StringRef linalgMarker, + ArrayRef permutation); + +/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and +/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. +LogicalResult tileAndFuseLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker); + +/// Emits a loop nest of `loop.for` with the proper body for `op`. +template +LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op); + +/// Emits a loop nest of `affine.for` with the proper body for `op`. +template +LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); + +/// Rewrite a linalg.generic into a suitable vector.contraction op. +LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op); + +/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` +/// and `iterator_types` permutated according to `permutation`. +LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, + ArrayRef permutation, + StringRef linalgMarker); + +/// Promote std.subviews feeding linalg operations +LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op); + +} // namespace linalg +} // namespace mlir + +#endif // DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..778d853aeefe91ae718f1b9eebd304792b7a4d67 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h @@ -0,0 +1,29 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_INTRINSICS_H_ +#define MLIR_DIALECT_LINALG_INTRINSICS_H_ + +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace linalg { +class CopyOp; +class FillOp; +class RangeOp; +class SliceOp; +namespace intrinsics { +using copy = mlir::edsc::intrinsics::OperationBuilder; +using fill = mlir::edsc::intrinsics::OperationBuilder; +using range = mlir::edsc::intrinsics::ValueBuilder; +using slice = mlir::edsc::intrinsics::ValueBuilder; +} // namespace intrinsics +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..996658b4c5c73797cf6d95541bf1b24cd9c3b9af --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -0,0 +1,226 @@ +//===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_UTILS_H_ +#define MLIR_DIALECT_LINALG_UTILS_H_ + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/EDSC/Helpers.h" + +#include "llvm/ADT/SetVector.h" + +namespace mlir { +class AffineExpr; +class AffineMap; +class OperationFolder; + +namespace edsc { + +/// A LoopRangeBuilder is a generic NestedBuilder for loop.for operations. 
+/// More specifically it is meant to be used as a temporary object for +/// representing any nested MLIR construct that is "related to" an mlir::Value +/// (for now an induction variable). +class LoopRangeBuilder : public NestedBuilder { +public: + /// Constructs a new loop.for and captures the associated induction + /// variable. A ValueHandle pointer is passed as the first argument and is the + /// *only* way to capture the loop induction variable. + LoopRangeBuilder(ValueHandle *iv, ValueHandle range); + LoopRangeBuilder(ValueHandle *iv, Value range); + LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range); + + LoopRangeBuilder(const LoopRangeBuilder &) = delete; + LoopRangeBuilder(LoopRangeBuilder &&) = default; + + LoopRangeBuilder &operator=(const LoopRangeBuilder &) = delete; + LoopRangeBuilder &operator=(LoopRangeBuilder &&) = default; + + /// The only purpose of this operator is to serve as a sequence point so that + /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is + /// scoped within a LoopRangeBuilder. + ValueHandle operator()(std::function fun = nullptr); +}; + +/// Helper class to sugar building loop.for loop nests from ranges. +/// This is similar to edsc::AffineLoopNestBuilder except it works on ranges +/// directly. In the current implementation it produces loop.for operations. +class LoopNestRangeBuilder { +public: + LoopNestRangeBuilder(ArrayRef ivs, + ArrayRef ranges); + LoopNestRangeBuilder(ArrayRef ivs, + ArrayRef ranges); + LoopNestRangeBuilder(ArrayRef ivs, + ArrayRef ranges); + edsc::ValueHandle operator()(std::function fun = nullptr); + +private: + SmallVector loops; +}; + +} // namespace edsc + +namespace linalg { +class LinalgDependenceGraph; + +struct FusionInfo { + LinalgOp originalProducer; + LinalgOp fusedProducer; +}; + +/// Checks whether the specific `producer` is the last write to exactly the +/// whole `consumedView`. 
This checks structural dominance, that the dependence +/// is a RAW without any interleaved write to any piece of `consumedView`. +bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value consumedView, + LinalgOp producer); + +/// Checks whether fusing the specific `producer` of the `consumedView` is +/// feasible. This checks `producer` is the last write of `consumedView` and +/// that no interleaved dependence would be violated (RAW, WAR or WAW). +bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, + Value consumedView, LinalgOp producer); + +/// Fuses producer into consumer if the producer is structurally feasible and +/// the fusion would not violate dependencies. +/// When non-null, the optional pointer `folder` is used to call into the +/// `createAndFold` builder method. If `folder` is null, the regular `create` +/// method is called. +Optional fuseProducerOf(OpBuilder &b, LinalgOp consumer, + unsigned consumerIdx, + const LinalgDependenceGraph &graph, + OperationFolder *folder = nullptr); + +/// Returns the linearized list of all view dimensions in a linalgOp. Applying +/// the inverse, concatenated loopToOperandRangeMaps to this list allows the +/// derivation of loop ranges for any linalgOp. +template +SmallVector getViewSizes(ConcreteOp linalgOp) { + SmallVector res; + for (auto v : linalgOp.getInputsAndOutputs()) { + MemRefType t = v->getType().template cast(); + for (unsigned i = 0; i < t.getRank(); ++i) + res.push_back(edsc::intrinsics::dim(v, i)); + } + return res; +} + +/// Returns the values obtained by applying `map` to the list of values. +/// When non-null, the optional pointer `folder` is used to call into the +/// `createAndFold` builder method. If `folder` is null, the regular `create` +/// method is called. 
+SmallVector applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, ArrayRef values, + OperationFolder *folder = nullptr); + +struct TiledLinalgOp { + LinalgOp op; + SmallVector loops; +}; + +/// Performs standalone tiling of a single LinalgOp by `tileSizes`. +/// and permute the loop nest according to `permutation` +/// The permutation is expressed as a list of integers that specify +/// the new ordering of the loop nest. The length of `permutation` +/// must be equal to the length of `tileSizes`. +/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with +/// `permutation = [1,2,0]`. All values in `permutation` must be +/// integers, in the range 0..`tileSizes.size()` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). An empty list +/// states for the identity permutation. +/// Returns a struct containing the tiled loops in the specified order +/// and the cloned op if successful, llvm::None otherwise. +/// When non-null, the optional pointer `folder` is used to call into the +/// `createAndFold` builder method. If `folder` is null, the regular `create` +/// method is called. +Optional tileLinalgOp(OpBuilder &b, LinalgOp op, + ArrayRef tileSizes, + ArrayRef permutation = {}, + OperationFolder *folder = nullptr); + +/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`. +/// and permute the loop nest according to `permutation` +/// The permutation is expressed as a list of integers that specify +/// the new ordering of the loop nest. The length of `permutation` +/// must be equal to the length of `tileSizes`. +/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with +/// `permutation = [1,2,0]`. All values in `permutation` must be +/// integers, in the range 0..`tileSizes.size()` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). An empty list +/// states for the identity permutation. 
+/// Returns a struct containing the tiled loops in the specified order +/// and the cloned op if successful, llvm::None otherwise. +/// When non-null, the optional pointer `folder` is used to call into the +/// `createAndFold` builder method. If `folder` is null, the regular `create` +/// method is called. +Optional tileLinalgOp(OpBuilder &b, LinalgOp op, + ArrayRef tileSizes, + ArrayRef permutation = {}, + OperationFolder *folder = nullptr); + +template +Optional tileLinalgOperation(OpBuilder &b, Operation *op, + Args... args) { + return tileLinalgOp(b, cast(op), args...); +} + +struct PromotionInfo { + Value buffer; + Value fullLocalView; + Value partialLocalView; +}; + +/// Promotes the `subViews` into a new buffer allocated at the insertion point +/// `b`. For now, promotion occurs in 3 steps: +/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary). +/// 2. Take a full view on the buffer and `linalg.fill` it with zeros (use +/// float zero for now). +/// 3. Take a partial slice of the full view in step 2. and copy into it. +/// Infers statically sized buffers from subViews unless `dynamicBuffers` is +/// true. +/// +/// Returns a list of PromotionInfo which hold the promoted buffer and the +/// full and partial views indexing into the buffer. +SmallVector +promoteSubViews(OpBuilder &b, Location loc, ArrayRef subViews, + bool dynamicBuffers = false, OperationFolder *folder = nullptr); + +/// Returns all the operands of `linalgOp` that are not views. +/// Asserts that these operands are value types to allow transformations like +/// tiling to just use the values when cloning `linalgOp`. +SmallVector getAssumedNonViewOperands(LinalgOp linalgOp); + +/// Apply the permutation defined by `permutation` to `inVec`. +/// Element `i` in `inVec` is mapped to location `j = permutation[i]`. +/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector +/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. 
+template +void applyPermutationToVector(SmallVector &inVec, + ArrayRef permutation) { + SmallVector auxVec(inVec.size()); + for (unsigned i = 0; i < permutation.size(); ++i) + auxVec[i] = inVec[permutation[i]]; + inVec = auxVec; +} + +/// Prepares the SubView promotion later performed by `promoteSubViews` +/// (where most of the transformation happens). It arranges the new +/// operands for `LinalgOp op` and deallocates the new buffer(s) +/// It is the entry point for declarative transformation +/// Returns the cloned `LinalgOp` with the new operands +LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, + llvm::SetVector subViews, + bool dynamicBuffers = false, + OperationFolder *folder = nullptr); + +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_UTILS_H_ diff --git a/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt b/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0fda882d3f54947bf72cb3891622c75942b63a69 --- /dev/null +++ b/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(LoopOps LoopOps) diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h new file mode 100644 index 0000000000000000000000000000000000000000..2617d7fd7839825def7fb6ee62749b0ccb26b7e4 --- /dev/null +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h @@ -0,0 +1,48 @@ +//===- Ops.h - Loop MLIR Operations -----------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines convenience types for working with loop operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LOOPOPS_OPS_H_ +#define MLIR_LOOPOPS_OPS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/LoopLikeInterface.h" + +namespace mlir { +namespace loop { + +class TerminatorOp; + +class LoopOpsDialect : public Dialect { +public: + LoopOpsDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "loop"; } +}; + +#define GET_OP_CLASSES +#include "mlir/Dialect/LoopOps/LoopOps.h.inc" + +// Insert `loop.terminator` at the end of the only region's only block if it +// does not have a terminator already. If a new `loop.terminator` is inserted, +// the location is specified by `loc`. If the region is empty, insert a new +// block first. +void ensureLoopTerminator(Region ®ion, Builder &builder, Location loc); + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +ForOp getForInductionVarOwner(Value val); + +} // end namespace loop +} // end namespace mlir +#endif // MLIR_LOOPOPS_OPS_H_ diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td new file mode 100644 index 0000000000000000000000000000000000000000..707b788aaa84654aa73a70aa15a53c6414813d36 --- /dev/null +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -0,0 +1,147 @@ +//===- Ops.td - Loop operation definitions ---------------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines MLIR loop operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LOOP_OPS +#define LOOP_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Transforms/LoopLikeInterface.td" + +def Loop_Dialect : Dialect { + let name = "loop"; + let cppNamespace = ""; +} + +// Base class for Loop dialect ops. +class Loop_Op traits = []> : + Op { + // For every standard op, there needs to be a: + // * void print(OpAsmPrinter &p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, + // OperationState &result) + // functions. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +def ForOp : Loop_Op<"for", + [DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"TerminatorOp">]> { + let summary = "for operation"; + let description = [{ + The "loop.for" operation represents a loop nest taking 3 SSA value as + operands that represent the lower bound, upper bound and step respectively. + The operation defines an SSA value for its induction variable. It has one + region capturing the loop body. The induction variable is represented as an + argument of this region. This SSA value always has type index, which is the + size of the machine word. The step is a value of type index, required to be + positive. + The lower and upper bounds specify a half-open range: the range includes the + lower bound but does not include the upper bound. + + The body region must contain exactly one block that terminates with + "loop.terminator". Calling ForOp::build will create such region and insert + the terminator, so will the parsing even in cases when it is absent from the + custom format. For example: + + loop.for %iv = %lb to %ub step %step { + ... 
// body + } + }]; + let arguments = (ins Index:$lowerBound, Index:$upperBound, Index:$step); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "Value lowerBound, Value upperBound, Value step"> + ]; + + let extraClassDeclaration = [{ + Block *getBody() { return ®ion().front(); } + Value getInductionVar() { return getBody()->getArgument(0); } + OpBuilder getBodyBuilder() { + return OpBuilder(getBody(), std::prev(getBody()->end())); + } + void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); } + void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); } + void setStep(Value step) { getOperation()->setOperand(2, step); } + }]; +} + +def IfOp : Loop_Op<"if", + [SingleBlockImplicitTerminator<"TerminatorOp">]> { + let summary = "if-then-else operation"; + let description = [{ + The "loop.if" operation represents an if-then-else construct for + conditionally executing two regions of code. The operand to an if operation + is a boolean value. The operation produces no results. For example: + + loop.if %b { + ... + } else { + ... + } + + The 'else' block is optional, and may be omitted. For + example: + + loop.if %b { + ... 
+ } + }]; + let arguments = (ins I1:$condition); + let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "Value cond, bool withElseRegion"> + ]; + + let extraClassDeclaration = [{ + OpBuilder getThenBodyBuilder() { + assert(!thenRegion().empty() && "Unexpected empty 'then' region."); + Block &body = thenRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + OpBuilder getElseBodyBuilder() { + assert(!elseRegion().empty() && "Unexpected empty 'else' region."); + Block &body = elseRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + }]; +} + +def TerminatorOp : + Loop_Op<"terminator", [NativeOpTrait<"IsTerminator">]> { + let summary = "cf terminator operation"; + let description = [{ + "loop.terminator" is a special terminator operation for blocks inside + loops. It terminates the region. This operation does _not_ have a custom + syntax. However, `std` control operations omit the terminator in their + custom syntax for brevity. + + loop.terminator + }]; + + // No custom parsing/printing form. + let parser = ?; + let printer = ?; + + // Fully specified by traits. 
+ let verifier = ?; +} + +#endif // LOOP_OPS diff --git a/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt b/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..90a61c4c194f1ea7f15a2f9ad0e51216eca3c508 --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(QuantOps QuantOps) diff --git a/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h b/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h new file mode 100644 index 0000000000000000000000000000000000000000..1a141e3b1b359d4b7151874c240f7693fd8fce17 --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h @@ -0,0 +1,67 @@ +//===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines support utilities for interoperating with FakeQuant* based +// QAT (Quantized Aware Training) computations, as implemented by TFLite. Note +// that FakeQuant* operators mix multiple concerns specific to how TFLite +// originally implemented quantization. As such, utilities here enforce +// opinions taken by that codebase (vs providing any amount of genericity). 
+// +// Specifically, it combines the following concerns, each of which would be +// independent variables in a more generic setup: +// - numBits and isSigned imply storage data type (uint8, int8, int16) +// - numBits < 8 is promoted to uint8 or int8 +// - "narrow_range" narrows the lower bound of the storage type's range by +// 1 +// - the specified min/max values are "nudged" so that the result has a zero +// that can be exactly expressed +// - min=max=0 implies scale=0 and zero_point=0 +// +// With the above assumptions applied, every conforming specified FakeQuant op +// can be represented by a UniformQuantizedType. This scheme is not expected to +// be generalized further in the future and should be considered to be a +// legacy set of rules. +// +// As canonically used in TensorFlow graphs, the presence of a FakeQuant node +// is a hint that the specific math represented here has been simulated at +// training time. As such, it is usually not advised to arbitrarily change +// quantization parameters derived from FakeQuant. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_ +#define MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_ + +#include "mlir/Dialect/QuantOps/QuantTypes.h" + +namespace mlir { +namespace quant { + +/// Converts per-layer FakeQuant attributes to the corresponding type. +/// In the event that the parameters cannot be converted, returns a nullptr +/// convertible Type and issues an appropriate error. +/// Note that there are multiple variants of a per-layer FakeQuant op, so +/// this function takes the attributes discretely vs taking a reference to the +/// originating op. +UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, + double rmin, double rmax, + bool narrowRange, Type expressedType, + bool isSigned = false); + +/// Converts per-channel FakeQuant attributes to the corresponding type. 
+/// In the event that the parameters cannot be converted, returns a nullptr +/// convertible Type and issues an appropriate error. +UniformQuantizedPerAxisType +fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, + ArrayRef rmins, ArrayRef rmax, + bool narrowRange, Type expressedType, + bool isSigned = false); +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANTOPS_FAKEQUANTSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/QuantOps/Passes.h b/mlir/include/mlir/Dialect/QuantOps/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..d3109775db2bb29d5a7ca64258aaa908126db632 --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/Passes.h @@ -0,0 +1,41 @@ +//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines all of the passes owned by the quantization dialect. As +// things mature, it is expected that passes specific to certain frontend or +// backend dialects will move to those dialects directly. For now, they are +// incubated here. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANTOPS_PASSES_H +#define MLIR_DIALECT_QUANTOPS_PASSES_H + +#include + +namespace mlir { +class FuncOp; +template class OpPassBase; + +namespace quant { + +/// Creates a pass that converts quantization simulation operations (i.e. +/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. +std::unique_ptr> createConvertSimulatedQuantPass(); + +/// Creates a pass that converts constants followed by a qbarrier to a +/// constant whose value is quantized. 
This is typically one of the last +/// passes done when lowering to express actual quantized arithmetic in a +/// low level representation. Because it modifies the constant, it is +/// destructive and cannot be undone. +std::unique_ptr> createConvertConstPass(); + +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANTOPS_PASSES_H diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h new file mode 100644 index 0000000000000000000000000000000000000000..9a4eec67c740f0d7086734b65a42129c705cd873 --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h @@ -0,0 +1,41 @@ +//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANTOPS_QUANTOPS_H_ +#define MLIR_DIALECT_QUANTOPS_QUANTOPS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace quant { + +/// Defines the 'Quantization' dialect +class QuantizationDialect : public Dialect { +public: + QuantizationDialect(MLIRContext *context); + + /// Parse a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + /// Print a type registered to this dialect. 
+ void printType(Type type, DialectAsmPrinter &os) const override; +}; + +#define GET_OP_CLASSES +#include "mlir/Dialect/QuantOps/QuantOps.h.inc" + +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANTOPS_QUANTOPS_H_ diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td new file mode 100644 index 0000000000000000000000000000000000000000..bbeb9419cc4088f2e85764180d850f06c22166c4 --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -0,0 +1,258 @@ +//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the operation definition file for Quantization. +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_QUANTOPS_QUANT_OPS_ +#define DIALECT_QUANTOPS_QUANT_OPS_ + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/QuantOps/QuantPredicates.td" + +def quant_Dialect : Dialect { + let name = "quant"; +} + +//===----------------------------------------------------------------------===// +// Base classes +//===----------------------------------------------------------------------===// + +class quant_Op traits> : + Op; + +//===----------------------------------------------------------------------===// +// Quantization casts +//===----------------------------------------------------------------------===// +// A QuantizeCast (qcast) represents a potential type shift from a quantizable +// type to a quantized type. +// +// At runtime, a qcast will apply the transformation expressed by its +// operand and result type. 
For flexibility during transformation, it is also +// possible to have a qcast that performs no transformation (both its +// operand and result type are quantizable). +// +// A qcast will typically originate from either: +// a) An expressed or implied constraint in the source dialect which signals +// that a certain level of quantization is possible or required. +// b) An inference made by a quantization algorithm indicating that a +// quantized representation may be acceptable. +// +// Especially early in transformation, it is common to have pairs of +// qcast/dcast at points where a transition to a quantized type is +// required. In addition, it is also common to have an identity qcast +// (where the operand and result type are not quantized) at all points where +// it is legal to use a quantized representation (but is not known to be +// acceptable). +def quant_QuantizeCastOp : quant_Op<"qcast", [NoSideEffect]> { + let arguments = (ins quant_RealValueType:$arg); + let results = (outs quant_RealValueType); +} + +// A DequantizeCast op (dcast) represents the inverse of a qcast, +// converting back from a quantized to quantizable (expressed) type. +// +// Like qcasts, a dcast is allowed to have both its operand and result +// as non quantized types. This facilitates transformations and marks edges +// where the computation must be carried out in the expressed type. +// +// Especially early in transformation, it is common to have dcasts on +// all operands to ops that must operate with the expressed type (typically +// math ops prior to lowering to target-specific, quantized kernels). +def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> { + let arguments = (ins quant_RealValueType:$arg); + let results = (outs quant_RealValueType); +} + +// A StorageCast (scast) represents a cast from or to a type based on the +// storage type and a type based on a corresponding quantized type. 
+//
+// This op exists to ensure type coherency between parts of the computation
+// which are operating directly on an underlying storage type and those which
+// operate on quantized values.
+//
+// Examples from storage to quantized type:
+//   i8 -> !quant<"uniform[i8:f32]{1.0}">
+//   tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+//   vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
+  let arguments = (ins quant_RealOrStorageValueType:$arg);
+  let results = (outs quant_RealOrStorageValueType);
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Training integration and instrumentation ops
+//===----------------------------------------------------------------------===//
+
+def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
+    [SameOperandsAndResultType, NoSideEffect]> {
+  let summary =
+      "Simulates the effect of uniform quantization with const range.";
+
+  let description = [{
+    Given a const min, max, num_bits and narrow_range attribute, applies the
+    same uniform quantization simulation as is done by the TensorFlow
+    fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
+    method and the quant-convert-simulated-quantization pass for further
+    details.
+  }];
+
+  let arguments = (ins
+    F32Tensor:$inputs,
+    F32Attr:$min,
+    F32Attr:$max,
+    // The bitwidth of the quantization; between 2 and 16, inclusive.
+    I64Attr:$num_bits,
+    // Quantization range starts from 0 or 1; starts from 1 if true.
+    DefaultValuedAttr:$narrow_range,
+    // The sign of the quantization.
+    DefaultValuedAttr:$is_signed
+  );
+
+  let results = (outs
+    F32Tensor:$outputs
+  );
+}
+
+def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis",
+    [SameOperandsAndResultType, NoSideEffect]> {
+  let summary =
+      "Simulates the effect of per axis uniform quantization with const range.";
+
+  let description = [{
+    Given a const min, max, num_bits and narrow_range attribute, applies the
+    same per axis uniform quantization simulation as is done by the TensorFlow
+    fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType()
+    utility method and the quant-convert-simulated-quantization pass for
+    further details.
+  }];
+
+  let arguments = (ins
+    F32Tensor:$inputs,
+    F32ArrayAttr:$min,
+    F32ArrayAttr:$max,
+    // The quantized dimension of the inputs tensor.
+    I64Attr:$axis,
+    // The bitwidth of the quantization; between 2 and 16, inclusive.
+    I64Attr:$num_bits,
+    // Quantization range starts from 0 or 1; starts from 1 if true.
+    DefaultValuedAttr:$narrow_range,
+    // The sign of the quantization.
+    DefaultValuedAttr:$is_signed
+  );
+
+  let results = (outs
+    F32Tensor:$outputs
+  );
+}
+
+def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
+  let summary =
+      "Indicates that statistics are resolved by reference.";
+
+  let description = [{
+    This op acts as an identity that, when encountered at runtime, should result
+    in statistics being collected about the value of its operand/result.
+    Such statistics will be stored with the provided key, allowing this node
+    to later be converted to a 'stats' op if statistics with that key have been
+    encountered.
+ }]; + + let arguments = (ins + quant_RealValueType:$arg, + StrAttr:$statsKey + ); + let results = (outs quant_RealValueType); +} + +def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> { + let summary = + "Identity op which associates statistics with the value."; + + let description = [{ + Associates statistics about the runtime ranges of values observed for + evaluations of this node. + + Statistics about the entire type are reported in the 'layerStats' attribute + and those for each axis, in the (optional) `axisStats` attribute. The + interpretation of each is determined by the last dimension of its shape. + Currently, only dim=2 is supported, which is interpreted as [min, max]. + + `layerStats` must be a rank 1 tensor: [2] + `axisStats` must be a rank 2 tensor: [N, 2], where N=the slice size + splitted by the `axis` dimension. For example: + , axis=3 => N=2 + , axis=2 => N=6 + }]; + + let arguments = (ins + quant_RealValueType:$arg, + ElementsAttr:$layerStats, + OptionalAttr:$axisStats, + OptionalAttr:$axis); + let results = (outs quant_RealValueType); + + let verifier = [{ + auto tensorArg = arg()->getType().dyn_cast(); + if (!tensorArg) return emitOpError("arg needs to be tensor type."); + + // Verify layerStats attribute. + { + auto layerStatsType = layerStats().getType(); + if (!layerStatsType.getElementType().isa()) { + return emitOpError( + "layerStats must have a floating point element type"); + } + if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { + return emitOpError("layerStats must have shape [2]"); + } + } + // Verify axisStats (optional) attribute. 
+ if (axisStats()) { + if (!axis()) return emitOpError("axis must be specified for axisStats"); + + auto shape = tensorArg.getShape(); + auto argSliceSize = std::accumulate(std::next(shape.begin(), + axis()->getSExtValue()), shape.end(), 1, std::multiplies()); + + auto axisStatsType = axisStats()->getType(); + if (!axisStatsType.getElementType().isa()) { + return emitOpError("axisStats must have a floating point element type"); + } + if (axisStatsType.getRank() != 2 || + axisStatsType.getDimSize(1) != 2 || + axisStatsType.getDimSize(0) != argSliceSize) { + return emitOpError("axisStats must have shape [N,2] " + "where N = the slice size defined by the axis dim"); + } + } + return success(); + }]; +} + +def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> { + let summary = + "Indicates that one point of the computation is coupled to another."; + + let description = [{ + Ordinarily, relationships between ops for the purposes of determining + compatible quantized types is explicit based on the use-def chain. However, + in some situations, a use may be separated from its def by arbitrary + external connections. In such a case, during analysis, all coupled_ref + nodes in a module which share a coupledKey will be considered to be + directly connected as via an identity op for the purpose of type inference. 
+ }]; + + let arguments = (ins + quant_RealValueType:$arg, + StrAttr:$coupledKey); + let results = (outs quant_RealValueType); +} + +#endif // DIALECT_QUANTOPS_QUANT_OPS_ diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td new file mode 100644 index 0000000000000000000000000000000000000000..7225dcc72db1ee4e652c0b80f2d982674d52151b --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td @@ -0,0 +1,63 @@ +//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Predicates for types in the Quantization dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_QUANTOPS_QUANT_PREDICATES_ +#define DIALECT_QUANTOPS_QUANT_PREDICATES_ + +//===----------------------------------------------------------------------===// +// Quantization type definitions +//===----------------------------------------------------------------------===// + +class quant_TypedPrimitiveOrContainer : + Type.predicate, + VectorOf<[etype]>.predicate]>, + "primitive/tensor/vector of " # etype.description>; + +// An implementation of QuantizedType. +def quant_QuantizedType : + Type()">, "QuantizedType">; + +// A primitive type that can represent a real value. This is either a +// floating point value or a quantized type. +def quant_RealPrimitiveType : + Type, + "real valued primitive (float or quantized type)">; + +// A primitive type that can represent a storage value. This is either an +// integer or quantized type. 
+def quant_StoragePrimitiveType : + Type, + "quantized storage primitive (integer or quantized type)">; + +// A primitive or container of RealPrimitiveType. +def quant_RealValueType : + quant_TypedPrimitiveOrContainer; + +// A primitive or container of StoragePrimitiveType. +def quant_StorageValueType : + quant_TypedPrimitiveOrContainer; + +// Either a real valued or storage primitive or container type. +def quant_RealOrStorageValueType : + Type>; + +// An implementation of UniformQuantizedType. +def quant_UniformQuantizedType : + Type()">, "UniformQuantizedType">; + +// Predicate for detecting a container or primitive of UniformQuantizedType. +def quant_UniformQuantizedValueType : + quant_TypedPrimitiveOrContainer; + +#endif // DIALECT_QUANTOPS_QUANT_PREDICATES_ diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h new file mode 100644 index 0000000000000000000000000000000000000000..daeb03744608d8d49442618910c33bface4d3bb7 --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h @@ -0,0 +1,402 @@ +//===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ +#define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace quant { + +class QuantizedIntegerType; + +namespace detail { + +struct QuantizedTypeStorage; +struct AnyQuantizedTypeStorage; +struct UniformQuantizedTypeStorage; +struct UniformQuantizedPerAxisTypeStorage; + +} // namespace detail + +namespace QuantizationTypes { +enum Kind { + Any = Type::FIRST_QUANTIZATION_TYPE, + UniformQuantized, + UniformQuantizedPerAxis, + LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis, +}; +} // namespace QuantizationTypes + +/// Enumeration of bit-mapped flags related to quantized types. +namespace QuantizationFlags { +enum FlagValue { + // Indicates that the storage type should be interpreted as a signed + // integer. The default is to interpret it as an unsigned value. + Signed = 1, +}; +} // namespace QuantizationFlags + +/// Base class for all quantized types known to this dialect. +/// All quantized types have: +/// - storageType: The (narrower) numeric type that is being used to +/// approximate some expressed type. +/// - expressedType: The type that is being approximated. +/// +/// The base class provides generic support for manipulating the types based +/// on these fields. +class QuantizedType : public Type { +public: + using ImplType = detail::QuantizedTypeStorage; + using Type::Type; + + /// The maximum number of bits supported for storage types. 
+ static constexpr unsigned MaxStorageBits = 32; + + static LogicalResult + verifyConstructionInvariants(Optional loc, MLIRContext *context, + unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Support method to enable LLVM-style type casting. + static bool classof(Type type) { + return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE && + type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE; + } + + /// Gets the minimum possible stored by a storageType. storageTypeMin must + /// be greater than or equal to this value. + static int64_t getDefaultMinimumForInteger(bool isSigned, + unsigned integralWidth) { + if (isSigned) { + return llvm::minIntN(integralWidth); + } + return 0; + } + + /// Gets the maximum possible stored by a storageType. storageTypeMax must + /// be less than or equal to this value. + static int64_t getDefaultMaximumForInteger(bool isSigned, + unsigned integralWidth) { + if (isSigned) { + return llvm::maxIntN(integralWidth); + } + return llvm::maxUIntN(integralWidth); + } + + /// Gets the original expressed type that this quantized type approximates. + /// Note that this presumes that the quantized type was always derived from + /// a floating point type, which in the broadest definition, is not true (i.e. + /// it could be some form of integral, fixed type or affine type in its own + /// right); however, at the high level, no examples of such usage are + /// presently known and the restriction serves some useful purposes (such as + /// always being able to reverse a transformation or measure error). In most + /// cases, this will be f32. + Type getExpressedType() const; + + /// Gets the flags associated with this type. Typically a more specific + /// accessor is appropriate. + unsigned getFlags() const; + + // Convenience helpers. + /// Whether the storage type should be interpreted as a signed quantity + /// (true) or an unsigned value (false). 
+ bool isSigned() const { + return (getFlags() & QuantizationFlags::Signed) == + QuantizationFlags::Signed; + } + + /// Gets the underlying type used for to store values. Note that this may + /// be signed or unsigned. Use the isSigned() accessor to differentiate. + Type getStorageType() const; + + /// The minimum value that storageType can take. + int64_t getStorageTypeMin() const; + + /// The maximum value that storageType can take. + int64_t getStorageTypeMax() const; + + /// Gets the integral bit width that the underlying storage type can exactly + /// represent. For integral storage types, this will just be their width. + unsigned getStorageTypeIntegralWidth() const; + + /// Returns whether the candidateExpressedType is a match for this + /// QuantizedType. This will be true if the candidate type is either a + /// primitive type or a container type whose element type equals this + /// QuantizedType's expressed type. + /// Examples of compatible candidateExpressedType: + /// !quant.uniform =~ f32 + /// !quant.uniform =~ tensor<4xf32> + bool isCompatibleExpressedType(Type candidateExpressedType); + + /// Returns the element type as a QuantizedType or nullptr if it is not + /// a quantized type. If the type is primitive, returns that. If it is a + /// container (vector/tensor), return the element type. + /// Examples: + /// !quant.uniform -> !quant.uniform + /// tensor<4x!quant.uniform -> quant.uniform + static QuantizedType getQuantizedElementType(Type primitiveOrContainerType); + + /// Casts from a type based on the storageType to a corresponding type based + /// on this type (returns nullptr if the cast is not valid). + /// Examples: + /// i8 -> !quant.uniform + /// tensor<4xi8> -> tensor<4x!quant.uniform> + /// vector<4xi8> -> vector<4x!quant.uniform> + Type castFromStorageType(Type candidateType); + + /// Casts from a type based on a QuantizedType to a corresponding type based + /// on the storageType (returns nullptr if the cast is not valid). 
+ /// This is the inverse of castFromStorageType(). + static Type castToStorageType(Type quantizedType); + + /// Casts from a type based on the expressedType to a corresponding type based + /// on this type (returns nullptr if the cast is not valid). + /// Examples: + /// f32 -> !quant.uniform + /// tensor<4xf32> -> tensor<4x!quant.uniform> + /// vector<4xf32> -> vector<4x!quant.uniform> + Type castFromExpressedType(Type candidateType); + + /// Casts from a type based on QuantizedType to a corresponding type based + /// on the expressedType (returns nullptr if the cast is not valid). + /// This is the inverse of castFromExpressedType. + static Type castToExpressedType(Type quantizedType); + + /// Casts from a type based on the expressedType to the equivalent type + /// based on storageType by way of this QuantizedType. Equivalent to: + /// QuantizedType::castToStorageType(castFromExpressedType(candidateType)) + /// (but with validity checks). + /// Example (for this = !quant.uniform): + /// tensor<4xf32> -> tensor<4xi8> + Type castExpressedToStorageType(Type candidateType); + +private: + /// Hide the following methods inherited from `Type`. It is almost certainly + /// a bug to call them from a `QuantizedType` object. Users should call + /// `getStorageType` or `getExpressedType` to get the underlying types + /// they want to inspect. + using Type::isBF16; + using Type::isF16; + using Type::isF32; + using Type::isF64; + using Type::isIndex; + using Type::isInteger; +}; + +/// A quantized type that maps storage to/from expressed types in an +/// unspecified way. +/// +/// Typical syntax: +/// quant.any +/// quant.any +/// quant.any> +/// +/// Note that for the any type, the expressed type is optional. +class AnyQuantizedType + : public Type::TypeBase { +public: + using Base::Base; + + /// Support method to enable LLVM-style type casting. 
+ static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; } + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static AnyQuantizedType get(unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static AnyQuantizedType getChecked(unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax, Location location); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult + verifyConstructionInvariants(Optional loc, MLIRContext *context, + unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); +}; + +/// Represents a family of uniform, quantized types. +/// +/// Each instance of this type expresses a mapping between real values (most +/// often expressed in floating point f32) and quantized values (either fixed +/// point or affine). +/// +/// The relationship is: +/// real_value = scale * (quantized_value - zero_point) +/// +/// It is used as part of high level graph transformations that have the goal +/// of re-expressing parts of a computation in terms of this common form for +/// more efficient execution at runtime. In addition, it is designed to be +/// expressive enough to facilitate lowering to precise types and operations +/// in target hardware. +/// +/// As a high-level type, focused on intermediate passes, this type holds +/// opinions consistent with high-level usage. If lowering math kernels below +/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific +/// instruction sets), it is expected that the information expressed here +/// will be used to drive low level codegen and target specific type selection, +/// but this type will likely be erased in the process. 
+/// +/// Syntax synopsis: +/// Per-layer, all parameters expressed: +/// !quant +/// Per-layer, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// Scale: A legal double value +/// ZeroPoint: An integer value +class UniformQuantizedType + : public Type::TypeBase { +public: + using Base::Base; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static UniformQuantizedType get(unsigned flags, Type storageType, + Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static UniformQuantizedType + getChecked(unsigned flags, Type storageType, Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax, + Location location); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult verifyConstructionInvariants( + Optional loc, MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { + return kind == QuantizationTypes::UniformQuantized; + } + + /// Gets the scale term. The scale designates the difference between the real + /// values corresponding to consecutive quantized values differing by 1. + double getScale() const; + + /// Gets the storage value corresponding to the real value 0 in the affine + /// equation. + int64_t getZeroPoint() const; + + // Fixed point values are real numbers divided by a scale. + // Currently, only signed storage types are treated as fixed point. + // A fixed point value can be obtained from an affine value by subtracting + // the zeroPoint. 
+ // In the future, this may be explicit versus implied by type and zeroPoint. + bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; } +}; + +/// Represents per-axis (also known as per-channel quantization). +/// +/// Syntax synopsis: +/// Per-axis, all parameters expressed: +/// !quant +/// Per-axis, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// QuantizedDim: An integer value +/// QuantParams: (Scale ':' ZeroPoint)+ +/// Scale: A legal double value +/// ZeroPoint: An integer value +class UniformQuantizedPerAxisType + : public Type::TypeBase { +public: + using Base::Base; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static UniformQuantizedPerAxisType + get(unsigned flags, Type storageType, Type expressedType, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static UniformQuantizedPerAxisType + getChecked(unsigned flags, Type storageType, Type expressedType, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax, Location location); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult verifyConstructionInvariants( + Optional loc, MLIRContext *context, unsigned flags, + Type storageType, Type expressedType, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { + return kind == QuantizationTypes::UniformQuantizedPerAxis; + } + + /// Gets the quantization scales. 
The scales designate the difference between
+  /// the real values corresponding to consecutive quantized values differing
+  /// by 1. The ith scale corresponds to the ith slice in the
+  /// quantized_dimension.
+  ArrayRef getScales() const;
+
+  /// Gets the storage values corresponding to the real value 0 in the affine
+  /// equation. The ith zero point corresponds to the ith slice in the
+  /// quantized_dimension.
+  ArrayRef getZeroPoints() const;
+
+  /// Specifies the dimension of the Tensor's shape that the scales and
+  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
+  /// with quantization params:
+  ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
+  /// will be quantized across the second dimension of t.
+  ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
+  ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
+  ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
+  int32_t getQuantizedDimension() const;
+
+  /// Fixed point values are real numbers divided by a scale.
+  /// Currently, only signed storage types are treated as fixed point.
+  /// A fixed point value can be obtained from an affine value by subtracting
+  /// the zeroPoint.
+  /// In the future, this may be explicit versus implied by type and zeroPoint.
+ bool isFixedPoint() const { + if (!isSigned()) + return false; + return llvm::all_of(getZeroPoints(), + [](int64_t zeroPoint) { return zeroPoint != 0; }); + } +}; + +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h b/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c40b9e6f0265d73142032964ee480fc97001c15e --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h @@ -0,0 +1,61 @@ +//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_ +#define MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_ + +namespace mlir { +class Attribute; +class Type; + +namespace quant { +class QuantizedType; +class UniformQuantizedType; +class UniformQuantizedValueConverter; + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType(). +/// Returns nullptr if the conversion is not supported. On success, stores the +/// converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. 
realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, + Type &outConvertedType); + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType() and casted to an +/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On +/// success, stores the converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttrUniform(Attribute realValue, + UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, + Type &outConvertedType); +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_ diff --git a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h new file mode 100644 index 0000000000000000000000000000000000000000..7c74fc56b8f03f4ef6b0af6cb41fb768c19de034 --- /dev/null +++ b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h @@ -0,0 +1,218 @@ +//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ +#define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ + +#include "mlir/Dialect/QuantOps/QuantTypes.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" + +namespace mlir { +namespace quant { + +/// Performs type conversion from an arbitrary input type to a type +/// that is expressed by a QuantizedType. +/// +/// This handles cases where the inputType is a supported primitive type +/// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported +/// elemental type. +/// +/// Since conversion often involves introspecting some attributes of the +/// input type in order to determine how to represent it, this is a two step +/// process. +struct ExpressedToQuantizedConverter { + /// Creates a converter for the given input type. + static const ExpressedToQuantizedConverter forInputType(Type inputType); + + /// Converts the inputType to be based on the given elemental type, + /// returning the new type (or nullptr and emit an error on failure). + Type convert(QuantizedType elementalType) const; + + /// Whether the conversion is legal. + explicit operator bool() const { return (bool)expressedType; } + + /// The input type that is being converted from. + /// This may be an elemental or composite type. + const Type inputType; + + /// Supported, elemental expressed type (i.e. f32). + /// Will be nullptr if conversion is not supported. + const Type expressedType; +}; + +/// Reference implementation of converting between real numbers and values +/// represented by a UniformQuantizedType. +/// Note that this is not expected to be speedy and may be superseded eventually +/// by a more optimal implementation. 
+/// Also, the interface assumes that quantization is done per-layer and will +/// need to be wider for various per-channel schemes. As such, this is a +/// placeholder. +class UniformQuantizedValueConverter { +public: + explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType) + : UniformQuantizedValueConverter( + uniformType.getScale(), + static_cast(uniformType.getZeroPoint()), + static_cast(uniformType.getStorageTypeMin()), + static_cast(uniformType.getStorageTypeMax()), + uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) { + assert(uniformType.getExpressedType().isa()); + assert(uniformType.getStorageType().isa()); + } + + UniformQuantizedValueConverter(double scale, double zeroPoint, + double clampMin, double clampMax, + uint32_t storageBitWidth, bool isSigned) + : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin), + clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint), + clampMinDouble(clampMin), clampMaxDouble(clampMax), + storageBitWidth(storageBitWidth), isSigned(isSigned), + roundMode(APFloat::rmNearestTiesToAway) {} + + UniformQuantizedValueConverter(double scale, double zeroPoint, + APFloat clampMin, APFloat clampMax, + uint32_t storageBitWidth, bool isSigned) + : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin), + clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint), + clampMinDouble(clampMin.convertToDouble()), + clampMaxDouble(clampMax.convertToDouble()), + storageBitWidth(storageBitWidth), isSigned(isSigned), + roundMode(APFloat::rmNearestTiesToAway) {} + + virtual APInt quantizeFloatToInt(APFloat expressedValue) const { + // This function is a performance critical code path in quantization + // since it runs for each single float parameter value. + + // Specialize f32->u8/i8 case to optimize performance. 
+ if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() && + storageBitWidth == 8 && + roundMode == llvm::APFloatBase::rmNearestTiesToAway) { + return quantizeF32ToInt8(expressedValue); + } + + bool lossy; + expressedValue.convert(scale.getSemantics(), roundMode, &lossy); + // fixedpoint = clamp(clampMin, clampMax, ( + // roundHalfToEven(expressed / scale) + zeroPoint)) + APFloat scaled = (expressedValue / scale); + scaled.roundToIntegral(roundMode); + scaled.add(zeroPoint, roundMode); + APFloat fixedpoint = llvm::minimum(scaled, clampMax); + fixedpoint = llvm::maximum(fixedpoint, clampMin); + + llvm::APSInt result(storageBitWidth, !isSigned); + fixedpoint.convertToInteger(result, roundMode, &lossy); + + return std::move(result); + } + + int64_t quantizeFloatToInt64(APFloat expressedValue) const { + APInt qValue = quantizeFloatToInt(expressedValue); + return isSigned ? qValue.getSExtValue() : qValue.getZExtValue(); + } + + virtual ~UniformQuantizedValueConverter() {} + +private: + // An optimized implementation to quantize f32 to i8/u8 with C++ native + // arithmetic. + virtual APInt quantizeF32ToInt8(APFloat expressedValue) const { + assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle()); + assert(storageBitWidth == 8); + assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway); + + const float realValue = expressedValue.convertToFloat(); + + const double scaled = realValue / scaleDouble + zeroPointDouble; + // Round to nearest integer with halfway cases rounded away from zero. 
+ const double scaledRounded = std::round(scaled); + const double clamped = + std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble); + + uint64_t signlessResult; + if (isSigned) { + int64_t clampedInt = static_cast(clamped); + memcpy(&signlessResult, &clampedInt, sizeof(clampedInt)); + } else { + signlessResult = static_cast(clamped); + } + return APInt(storageBitWidth, signlessResult); + } + + // Keep both APFloat and double versions of the quantization parameters + // around since they will be used in generic and specialized arithmetic, + // respectively. + const APFloat scale; + const APFloat zeroPoint; + const APFloat clampMin; + const APFloat clampMax; + + const double scaleDouble; + const double zeroPointDouble; + const double clampMinDouble; + const double clampMaxDouble; + + const uint32_t storageBitWidth; + const bool isSigned; + const llvm::APFloat::roundingMode roundMode; +}; + +/// A utility class to quantize an attribute by the per-axis quantization +/// parameters. The size of the quantization dim in the converted elements +/// attribute should match the size of scales/zeroPoints vectors in the +/// quantization parameters. +class UniformQuantizedPerAxisValueConverter { +public: + explicit UniformQuantizedPerAxisValueConverter( + UniformQuantizedPerAxisType uniformType) + : scales(uniformType.getScales()), + zeroPoints(uniformType.getZeroPoints()), + clampMin(static_cast(uniformType.getStorageTypeMin())), + clampMax(static_cast(uniformType.getStorageTypeMax())), + storageBitWidth(uniformType.getStorageTypeIntegralWidth()), + isSigned(uniformType.isSigned()), + quantizationDim(uniformType.getQuantizedDimension()) { + assert(uniformType.getExpressedType().isa()); + assert(uniformType.getStorageType().isa()); + assert(scales.size() == zeroPoints.size()); + } + + /// Quantize an Attribute by the quantization parameters. Return nullptr if + /// the conversion fails or the input array isn't an ElementsAttr. 
+ ElementsAttr convert(Attribute realValue); + +private: + /// Quantize an DenseFPElementsAttr by the quantization parameters. + DenseElementsAttr convert(DenseFPElementsAttr attr); + + /// Get a uniform converter for the index-th chunk along the quantizationDim. + /// All the elements in this chunk is quantized by the returned converter. + UniformQuantizedValueConverter getPerChunkConverter(int index) const { + UniformQuantizedValueConverter converter(scales[index], zeroPoints[index], + clampMin, clampMax, + storageBitWidth, isSigned); + return converter; + } + + const ArrayRef scales; + const ArrayRef zeroPoints; + const APFloat clampMin; + const APFloat clampMax; + const uint32_t storageBitWidth; + const bool isSigned; + int32_t quantizationDim; +}; + +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/SDBM/SDBM.h b/mlir/include/mlir/Dialect/SDBM/SDBM.h new file mode 100644 index 0000000000000000000000000000000000000000..c8a0eec8ca84c88560b87ce22b9d39234d63290b --- /dev/null +++ b/mlir/include/mlir/Dialect/SDBM/SDBM.h @@ -0,0 +1,197 @@ +//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined +// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SDBM_SDBM_H +#define MLIR_DIALECT_SDBM_SDBM_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir { + +class MLIRContext; +class SDBMDialect; +class SDBMExpr; +class SDBMTermExpr; + +/// A utility class for SDBM to represent an integer with potentially infinite +/// positive value. This uses the largest value of int64_t to represent infinity +/// and redefines the arithmetic operators so that the infinity "saturates": +/// inf + x = inf, +/// inf - x = inf. +/// If a sum of two finite values reaches the largest value of int64_t, the +/// behavior of IntInfty is undefined (in practice, it asserts), similarly to +/// regular signed integer overflow. +class IntInfty { +public: + constexpr static int64_t infty = std::numeric_limits::max(); + + /*implicit*/ IntInfty(int64_t v) : value(v) {} + + IntInfty &operator=(int64_t v) { + value = v; + return *this; + } + + static IntInfty infinity() { return IntInfty(infty); } + + int64_t getValue() const { return value; } + explicit operator int64_t() const { return value; } + + bool isFinite() { return value != infty; } + +private: + int64_t value; +}; + +inline IntInfty operator+(IntInfty lhs, IntInfty rhs) { + if (!lhs.isFinite() || !rhs.isFinite()) + return IntInfty::infty; + + // Check for overflows, treating the sum of two values adding up to INT_MAX as + // overflow. Convert values to unsigned to get an extra bit and avoid the + // undefined behavior of signed integer overflows. 
+ assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 || + static_cast(lhs.getValue()) + + static_cast(rhs.getValue()) < + static_cast(std::numeric_limits::max())) && + "IntInfty overflow"); + // Check for underflows by converting values to unsigned to avoid undefined + // behavior of signed integers perform the addition (bitwise result is same + // because numbers are required to be two's complement in C++) and check if + // the sign bit remains negative. + assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 || + ((static_cast(lhs.getValue()) + + static_cast(rhs.getValue())) >> + 63) == 1) && + "IntInfty underflow"); + + return lhs.getValue() + rhs.getValue(); +} + +inline bool operator<(IntInfty lhs, IntInfty rhs) { + return lhs.getValue() < rhs.getValue(); +} + +inline bool operator<=(IntInfty lhs, IntInfty rhs) { + return lhs.getValue() <= rhs.getValue(); +} + +inline bool operator==(IntInfty lhs, IntInfty rhs) { + return lhs.getValue() == rhs.getValue(); +} + +inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); } + +/// Striped difference-bound matrix is a representation of an integer set bound +/// by a system of SDBMExprs interpreted as inequalities "expr <= 0". +class SDBM { +public: + /// Obtain an SDBM from a list of SDBM expressions treated as inequalities and + /// equalities with zero. + static SDBM get(ArrayRef inequalities, + ArrayRef equalities); + + void getSDBMExpressions(SDBMDialect *dialect, + SmallVectorImpl &inequalities, + SmallVectorImpl &equalities); + + void print(raw_ostream &os); + void dump(); + + IntInfty operator()(int i, int j) { return at(i, j); } + +private: + /// Get the given element of the difference bounds matrix. First index + /// corresponds to the negative term of the difference, second index + /// corresponds to the positive term of the difference. 
+ IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; } + + /// Populate `inequalities` and `equalities` based on the values at(row,col) + /// and at(col,row) of the DBM. Depending on the values being finite and + /// being subsumed by stripe expressions, this may or may not add elements to + /// the lists of equalities and inequalities. + void convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr, + SDBMTermExpr colExpr, + SmallVectorImpl &inequalities, + SmallVectorImpl &equalities); + + /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only + /// adds new inequalities if the inequality is not trivially true. + void convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr, + SmallVectorImpl &inequalities); + + /// Get the total number of elements in the matrix. + unsigned getNumVariables() const { + return 1 + numDims + numSymbols + numTemporaries; + } + + /// Get the position in the matrix that corresponds to the given dimension. + unsigned getDimPosition(unsigned position) const { return 1 + position; } + + /// Get the position in the matrix that corresponds to the given symbol. + unsigned getSymbolPosition(unsigned position) const { + return 1 + numDims + position; + } + + /// Get the position in the matrix that corresponds to the given temporary. + unsigned getTemporaryPosition(unsigned position) const { + return 1 + numDims + numSymbols + position; + } + + /// Number of dimensions in the system, + unsigned numDims; + /// Number of symbols in the system. + unsigned numSymbols; + /// Number of temporary variables in the system. + unsigned numTemporaries; + + /// Difference bounds matrix, stored as a linearized row-major vector. + /// Each value in this matrix corresponds to an inequality + /// + /// v@col - v@row <= at(row, col) + /// + /// where v@col and v@row are the variables that correspond to the linearized + /// position in the matrix. 
The positions correspond to + /// + /// - constant 0 (producing constraints v@col <= X and -v@row <= Y); + /// - SDBM expression dimensions (d0, d1, ...); + /// - SDBM expression symbols (s0, s1, ...); + /// - temporary variables (t0, t1, ...). + /// + /// Temporary variables are introduced to represent expressions that are not + /// trivially a difference between two variables. For example, if one side of + /// a difference expression is itself a stripe expression, it will be replaced + /// with a temporary variable assigned equal to this expression. + /// + /// Infinite entries in the matrix correspond to an absence of a + /// constraint: + /// + /// v@col - v@row <= infinity + /// + /// is trivially true. Negated values at symmetric positions in the matrix + /// allow one to couple two inequalities into a single equality. + std::vector matrix; + + /// The mapping between the indices of variables in the DBM and the stripe + /// expressions they are equal to. These expressions are stored as they + /// appeared when constructing an SDBM from SDBMExprs, in particular no + /// temporaries can appear in these expressions. This removes the need to + /// iteratively substitute definitions of the temporaries in the reverse + /// conversion. + DenseMap stripeToPoint; +}; + +} // namespace mlir + +#endif // MLIR_DIALECT_SDBM_SDBM_H diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h new file mode 100644 index 0000000000000000000000000000000000000000..501c66140f026eb13251946f8f7f294b294d4a09 --- /dev/null +++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h @@ -0,0 +1,32 @@ +//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H +#define MLIR_DIALECT_SDBM_SDBMDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/Support/StorageUniquer.h" + +namespace mlir { +class MLIRContext; + +class SDBMDialect : public Dialect { +public: + SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {} + + static StringRef getDialectNamespace() { return "sdbm"; } + + /// Get the uniquer for SDBM expressions. This should not be used directly. + StorageUniquer &getUniquer() { return uniquer; } + +private: + StorageUniquer uniquer; +}; +} // namespace mlir + +#endif // MLIR_DIALECT_SDBM_SDBMDIALECT_H diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h new file mode 100644 index 0000000000000000000000000000000000000000..84a9a8405a8394576c1eacd46c07c4dec124b8da --- /dev/null +++ b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h @@ -0,0 +1,576 @@ +//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A striped difference-bound matrix (SDBM) expression is a constant expression, +// an identifier, a binary expression with constant RHS and +, stripe operators +// or a difference expression between two identifiers. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SDBM_SDBMEXPR_H +#define MLIR_DIALECT_SDBM_SDBMEXPR_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMapInfo.h" + +namespace mlir { + +class AffineExpr; +class MLIRContext; + +enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg }; + +namespace detail { +struct SDBMExprStorage; +struct SDBMBinaryExprStorage; +struct SDBMDiffExprStorage; +struct SDBMTermExprStorage; +struct SDBMConstantExprStorage; +struct SDBMNegExprStorage; +} // namespace detail + +class SDBMConstantExpr; +class SDBMDialect; +class SDBMDimExpr; +class SDBMSymbolExpr; +class SDBMTermExpr; + +/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side +/// expression for the SDBM framework. SDBM expressions are a subset of affine +/// expressions supporting low-complexity algorithms for the operations used in +/// loop transformations. In particular, are supported: +/// - constant expressions; +/// - single variables (dimensions and symbols) with +1 or -1 coefficient; +/// - stripe expressions: "x # C", where "x" is a single variable or another +/// stripe expression, "#" is the stripe operator, and "C" is a constant +/// expression; "#" is defined as x - x mod C. +/// - sum expressions between single variable/stripe expressions and constant +/// expressions; +/// - difference expressions between single variable/stripe expressions. +/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing +/// and operating on SDBM expressions. For example, it requires the LHS of a +/// sum expression to be a single variable or a stripe expression. These +/// restrictions are intended to force the caller to perform the necessary +/// simplifications to stay within the SDBM domain, because SDBM expressions do +/// not combine in more cases than they do. This choice may be reconsidered in +/// the future. 
+/// +/// SDBM expressions are grouped into the following structure +/// - expression +/// - varying +/// - direct +/// - sum <- (term, constant) +/// - term +/// - symbol +/// - dimension +/// - stripe <- (direct, constant) +/// - negation <- (direct) +/// - difference <- (direct, term) +/// - constant +/// The notation <- (...) denotes the types of subexpressions a compound +/// expression can combine. The tree of subexpressions essentially imposes the +/// following canonicalization rules: +/// - constants are always folded; +/// - constants can only appear on the RHS of an expression; +/// - double negation must be elided; +/// - an additive constant term is only allowed in a sum expression, and +/// should be sunk into the nearest such expression in the tree; +/// - zero constant expression can only appear at the top level. +/// +/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by +/// an MLIRContext, and should be used by-value. They are uniqued in the +/// MLIRContext and immortal. +class SDBMExpr { +public: + using ImplType = detail::SDBMExprStorage; + SDBMExpr() : impl(nullptr) {} + /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {} + + /// SDBM expressions are thin wrappers around a unique'ed immutable pointer, + /// which makes them trivially assignable and trivially copyable. + SDBMExpr(const SDBMExpr &) = default; + SDBMExpr &operator=(const SDBMExpr &) = default; + + /// SDBM expressions can be compared straight-forwardly. + bool operator==(const SDBMExpr &other) const { return impl == other.impl; } + bool operator!=(const SDBMExpr &other) const { return !(*this == other); } + + /// SDBM expressions are convertible to `bool`: null expressions are converted + /// to false, non-null expressions are converted to true. + explicit operator bool() const { return impl != nullptr; } + bool operator!() const { return !static_cast(*this); } + + /// Negate the given SDBM expression. 
+ SDBMExpr operator-(); + + /// Prints the SDBM expression. + void print(raw_ostream &os) const; + void dump() const; + + /// LLVM-style casts. + template bool isa() const { return U::isClassFor(*this); } + template U dyn_cast() const { + if (!isa()) + return {}; + return U(const_cast(this)->impl); + } + template U cast() const { + assert(isa() && "cast to incorrect subtype"); + return U(const_cast(this)->impl); + } + + /// Support for LLVM hashing. + ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); } + + /// Returns the kind of the SDBM expression. + SDBMExprKind getKind() const; + + /// Returns the MLIR context in which this expression lives. + MLIRContext *getContext() const; + + /// Returns the SDBM dialect instance. + SDBMDialect *getDialect() const; + + /// Convert the SDBM expression into an Affine expression. This always + /// succeeds because SDBM are a subset of affine. + AffineExpr getAsAffineExpr() const; + + /// Try constructing an SDBM expression from the given affine expression. + /// This may fail if the affine expression is not representable as SDBM, in + /// which case llvm::None is returned. The conversion procedure recognizes + /// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B) + /// patterns for the stripe expression. + static Optional tryConvertAffineExpr(AffineExpr affine); + +protected: + ImplType *impl; +}; + +/// SDBM constant expression, wraps a 64-bit integer. +class SDBMConstantExpr : public SDBMExpr { +public: + using ImplType = detail::SDBMConstantExprStorage; + + using SDBMExpr::SDBMExpr; + + /// Obtain or create a constant expression unique'ed in the given dialect + /// (which belongs to a context). 
+ static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Constant; + } + + int64_t getValue() const; +}; + +/// SDBM varying expression can be one of: +/// - input variable expression; +/// - stripe expression; +/// - negation (product with -1) of either of the above. +/// - sum of a varying and a constant expression +/// - difference between varying expressions +class SDBMVaryingExpr : public SDBMExpr { +public: + using ImplType = detail::SDBMExprStorage; + using SDBMExpr::SDBMExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId || + expr.getKind() == SDBMExprKind::Neg || + expr.getKind() == SDBMExprKind::Stripe || + expr.getKind() == SDBMExprKind::Add || + expr.getKind() == SDBMExprKind::Diff; + } +}; + +/// SDBM direct expression includes exactly one variable (symbol or dimension), +/// which is not negated in the expression. It can be one of: +/// - term expression; +/// - sum expression. +class SDBMDirectExpr : public SDBMVaryingExpr { +public: + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// If this is a sum expression, return its variable part, otherwise return + /// self. + SDBMTermExpr getTerm(); + + /// If this is a sum expression, return its constant part, otherwise return 0. + int64_t getConstant(); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId || + expr.getKind() == SDBMExprKind::Stripe || + expr.getKind() == SDBMExprKind::Add; + } +}; + +/// SDBM term expression can be one of: +/// - single variable expression; +/// - stripe expression. +/// Stripe expressions are treated as terms since, in the SDBM domain, they are +/// attached to temporary variables and can appear anywhere a variable can. 
+class SDBMTermExpr : public SDBMDirectExpr { +public: + using SDBMDirectExpr::SDBMDirectExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId || + expr.getKind() == SDBMExprKind::Stripe; + } +}; + +/// SDBM sum expression. LHS is a term expression and RHS is a constant. +class SDBMSumExpr : public SDBMDirectExpr { +public: + using ImplType = detail::SDBMBinaryExprStorage; + using SDBMDirectExpr::SDBMDirectExpr; + + /// Obtain or create a sum expression unique'ed in the given context. + static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs); + + static bool isClassFor(const SDBMExpr &expr) { + SDBMExprKind kind = expr.getKind(); + return kind == SDBMExprKind::Add; + } + + SDBMTermExpr getLHS() const; + SDBMConstantExpr getRHS() const; +}; + +/// SDBM difference expression. LHS is a direct expression, i.e. it may be a +/// sum of a term and a constant. RHS is a term expression. Thus the +/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as +/// diff(sum(t1, C), t2) +/// and it is possible to extract the constant factor without negating it. +class SDBMDiffExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMDiffExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a difference expression unique'ed in the given context. + static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Diff; + } + + SDBMDirectExpr getLHS() const; + SDBMTermExpr getRHS() const; +}; + +/// SDBM stripe expression "x # C" where "x" is a term expression, "C" is a +/// constant expression and "#" is the stripe operator defined as: +/// x # C = x - x mod C. 
+class SDBMStripeExpr : public SDBMTermExpr { +public: + using ImplType = detail::SDBMBinaryExprStorage; + using SDBMTermExpr::SDBMTermExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Stripe; + } + + static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor); + + SDBMDirectExpr getLHS() const; + SDBMConstantExpr getStripeFactor() const; +}; + +/// SDBM "input" variable expression can be either a dimension identifier or +/// a symbol identifier. When used to define SDBM functions, dimensions are +/// interpreted as function arguments while symbols are treated as unknown but +/// constant values, hence the name. +class SDBMInputExpr : public SDBMTermExpr { +public: + using ImplType = detail::SDBMTermExprStorage; + using SDBMTermExpr::SDBMTermExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId; + } + + unsigned getPosition() const; +}; + +/// SDBM dimension expression. Dimensions correspond to function arguments +/// when defining functions using SDBM expressions. +class SDBMDimExpr : public SDBMInputExpr { +public: + using ImplType = detail::SDBMTermExprStorage; + using SDBMInputExpr::SDBMInputExpr; + + /// Obtain or create a dimension expression unique'ed in the given dialect + /// (which belongs to a context). + static SDBMDimExpr get(SDBMDialect *dialect, unsigned position); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId; + } +}; + +/// SDBM symbol expression. Symbols correspond to symbolic constants when +/// defining functions using SDBM expressions. +class SDBMSymbolExpr : public SDBMInputExpr { +public: + using ImplType = detail::SDBMTermExprStorage; + using SDBMInputExpr::SDBMInputExpr; + + /// Obtain or create a symbol expression unique'ed in the given dialect (which + /// belongs to a context). 
+ static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::SymbolId; + } +}; + +/// Negation of an SDBM variable expression. Equivalent to multiplying the +/// expression with -1 (SDBM does not support other coefficients that 1 and -1). +class SDBMNegExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMNegExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a negation expression unique'ed in the given context. + static SDBMNegExpr get(SDBMDirectExpr var); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Neg; + } + + SDBMDirectExpr getVar() const; +}; + +/// A visitor class for SDBM expressions. Calls the kind-specific function +/// depending on the kind of expression it visits. +template class SDBMVisitor { +public: + /// Visit the given SDBM expression, dispatching to kind-specific functions. + Result visit(SDBMExpr expr) { + auto *derived = static_cast(this); + switch (expr.getKind()) { + case SDBMExprKind::Add: + case SDBMExprKind::Diff: + case SDBMExprKind::DimId: + case SDBMExprKind::SymbolId: + case SDBMExprKind::Neg: + case SDBMExprKind::Stripe: + return derived->visitVarying(expr.cast()); + case SDBMExprKind::Constant: + return derived->visitConstant(expr.cast()); + } + + llvm_unreachable("unsupported SDBM expression kind"); + } + + /// Traverse the SDBM expression tree calling `visit` on each node + /// in depth-first preorder. + void walkPreorder(SDBMExpr expr) { return walk(expr); } + + /// Traverse the SDBM expression tree calling `visit` on each node in + /// depth-first postorder. + void walkPostorder(SDBMExpr expr) { return walk(expr); } + +protected: + /// Default visitors do nothing. 
+ void visitSum(SDBMSumExpr) {} + void visitDiff(SDBMDiffExpr) {} + void visitStripe(SDBMStripeExpr) {} + void visitDim(SDBMDimExpr) {} + void visitSymbol(SDBMSymbolExpr) {} + void visitNeg(SDBMNegExpr) {} + void visitConstant(SDBMConstantExpr) {} + + /// Default implementation of visitDirect dispatches to the dedicated for sums + /// or delegates to visitTerm for the other expression kinds. Concrete + /// visitors can overload it. + Result visitDirect(SDBMDirectExpr expr) { + auto *derived = static_cast(this); + if (auto sum = expr.dyn_cast()) + return derived->visitSum(sum); + else + return derived->visitTerm(expr.cast()); + } + + /// Default implementation of visitTerm dispatches to the special functions + /// for stripes and other variables. Concrete visitors can override it. + Result visitTerm(SDBMTermExpr expr) { + auto *derived = static_cast(this); + if (expr.getKind() == SDBMExprKind::Stripe) + return derived->visitStripe(expr.cast()); + else + return derived->visitInput(expr.cast()); + } + + /// Default implementation of visitInput dispatches to the special + /// functions for dimensions or symbols. Concrete visitors can override it to + /// visit all variables instead. + Result visitInput(SDBMInputExpr expr) { + auto *derived = static_cast(this); + if (expr.getKind() == SDBMExprKind::DimId) + return derived->visitDim(expr.cast()); + else + return derived->visitSymbol(expr.cast()); + } + + /// Default implementation of visitVarying dispatches to the special + /// functions for variables and negations thereof. Concrete visitors can + /// override it to visit all variables and negations instead. 
+ Result visitVarying(SDBMVaryingExpr expr) { + auto *derived = static_cast(this); + if (auto var = expr.dyn_cast()) + return derived->visitDirect(var); + else if (auto neg = expr.dyn_cast()) + return derived->visitNeg(neg); + else if (auto diff = expr.dyn_cast()) + return derived->visitDiff(diff); + + llvm_unreachable("unhandled subtype of varying SDBM expression"); + } + + template void walk(SDBMExpr expr) { + if (isPreorder) + visit(expr); + if (auto sumExpr = expr.dyn_cast()) { + walk(sumExpr.getLHS()); + walk(sumExpr.getRHS()); + } else if (auto diffExpr = expr.dyn_cast()) { + walk(diffExpr.getLHS()); + walk(diffExpr.getRHS()); + } else if (auto stripeExpr = expr.dyn_cast()) { + walk(stripeExpr.getLHS()); + walk(stripeExpr.getStripeFactor()); + } else if (auto negExpr = expr.dyn_cast()) { + walk(negExpr.getVar()); + } + if (!isPreorder) + visit(expr); + } +}; + +/// Overloaded arithmetic operators for SDBM expressions asserting that their +/// arguments have the proper SDBM expression subtype. Perform canonicalization +/// and constant folding on these expressions. +namespace ops_assertions { + +/// Add two SDBM expressions. At least one of the expressions must be a +/// constant or a negation, but both expressions cannot be negations +/// simultaneously. +SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs); +inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) { + return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs); +} +inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) { + return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs; +} + +/// Subtract an SDBM expression from another SDBM expression. Both expressions +/// must not be difference expressions. 
+SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs); +inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) { + return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs); +} +inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) { + return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs; +} + +/// Construct a stripe expression from a positive expression and a positive +/// constant stripe factor. +SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor); +inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) { + return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor)); +} +} // namespace ops_assertions + +} // end namespace mlir + +namespace llvm { +// SDBMExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMExpr(static_cast(pointer)); + } + static mlir::SDBMExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMExpr(static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMDirectExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMDirectExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMDirectExpr( + static_cast(pointer)); + } + static mlir::SDBMDirectExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMDirectExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMDirectExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMTermExpr hash just like pointers. 
+template <> struct DenseMapInfo { + static mlir::SDBMTermExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMTermExpr(static_cast(pointer)); + } + static mlir::SDBMTermExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMTermExpr(static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMTermExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMTermExpr lhs, mlir::SDBMTermExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMConstantExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMConstantExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMConstantExpr( + static_cast(pointer)); + } + static mlir::SDBMConstantExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMConstantExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMConstantExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + +#endif // MLIR_DIALECT_SDBM_SDBMEXPR_H diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fc7180de6cbeca1d97d19ca0e00c0dca0c1607a6 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -0,0 +1,19 @@ +set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td) +mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls) +mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRSPIRVLoweringStructGen) + +add_mlir_dialect(SPIRVOps SPIRVOps) + +set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) +mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls) +mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRSPIRVEnumsIncGen) + 
+set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) +mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization) +add_public_tablegen_target(MLIRSPIRVSerializationGen) + +set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) +mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils) +add_public_tablegen_target(MLIRSPIRVOpUtilsGen) diff --git a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..329caa2d3aa2cfd6d7f1990a1d3f6fe7a9acc2e7 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h @@ -0,0 +1,71 @@ +//===-- LayoutUtils.h - Decorate composite type with layout information ---===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines utilities used to get alignment and layout information for +// types in SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_SPIRV_LAYOUTUTILS_H_ +#define MLIR_DIALECT_SPIRV_LAYOUTUTILS_H_ + +#include + +namespace mlir { +class Type; +class VectorType; +namespace spirv { +class StructType; +class ArrayType; +} // namespace spirv + +/// According to the Vulkan spec "14.5.4. Offset and Stride Assignment": +/// "There are different alignment requirements depending on the specific +/// resources and on the features enabled on the device." +/// +/// There are 3 types of alignment: scalar, base, extended. +/// See the spec for details. +/// +/// Note: Even if scalar alignment is supported, it is generally more +/// performant to use the base alignment. So here the calculation is based on +/// base alignment. +/// +/// The memory layout must obey the following rules: +/// 1. 
The Offset decoration of any member must be a multiple of its alignment. +/// 2. Any ArrayStride or MatrixStride decoration must be a multiple of the +/// alignment of the array or matrix as defined above. +/// +/// According to the SPIR-V spec: +/// "The ArrayStride, MatrixStride, and Offset decorations must be large +/// enough to hold the size of the objects they affect (that is, specifying +/// overlap is invalid)." +class VulkanLayoutUtils { +public: + using Size = uint64_t; + + /// Returns a new StructType with layout info. Assigns the type size in bytes + /// to the `size`. Assigns the type alignment in bytes to the `alignment`. + static spirv::StructType decorateType(spirv::StructType structType, + Size &size, Size &alignment); + /// Checks whether a type is legal in terms of Vulkan layout info + /// decoration. A type is dynamically illegal if it's a composite type in the + /// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage + /// Classes without layout information. + static bool isLegalType(Type type); + +private: + static Type decorateType(Type type, Size &size, Size &alignment); + static Type decorateType(VectorType vectorType, Size &size, Size &alignment); + static Type decorateType(spirv::ArrayType arrayType, Size &size, + Size &alignment); + /// Calculates the alignment for the given scalar type. + static Size getScalarTypeAlignment(Type scalarType); +}; + +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_LAYOUTUTILS_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h new file mode 100644 index 0000000000000000000000000000000000000000..68f149b54d57dd498f2d3c4660243cb1db8e8d11 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -0,0 +1,40 @@ +//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_PASSES_H_ +#define MLIR_DIALECT_SPIRV_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace spirv { + +class ModuleOp; +/// Creates a module pass that converts composite types used by objects in the +/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage +/// classes with layout information. +/// Right now this pass only supports Vulkan layout rules. +std::unique_ptr> +createDecorateSPIRVCompositeTypeLayoutPass(); + +/// Creates a module pass that lowers the ABI attributes specified during SPIR-V +/// Lowering. Specifically, +/// 1) Creates the global variables for arguments of entry point function using +/// the specification in the ABI attributes for each argument. +/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point +/// functions using the specification in the EntryPointAttr. +std::unique_ptr> createLowerABIAttributesPass(); + +} // namespace spirv +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td new file mode 100644 index 0000000000000000000000000000000000000000..39858f357ff17fc7923851848f78c77d8b6f0d93 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td @@ -0,0 +1,537 @@ +//===-- SPIRVArithmeticOps.td - MLIR SPIR-V Arithmetic Ops -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains arithmetic ops for the SPIR-V dialect. It corresponds
+// to "3.32.13. Arithmetic Instructions" of the SPIR-V specification.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_ARITHMETIC_OPS
+#define SPIRV_ARITHMETIC_OPS
+
+include "mlir/Dialect/SPIRV/SPIRVBase.td"
+
+class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
+                             list<OpTrait> traits = []> :
+      // Operands type same as result type.
+      SPV_BinaryOp<mnemonic, type, type,
+                   !listconcat(traits,
+                               [NoSideEffect, SameOperandsAndResultType])>;
+
+class SPV_ArithmeticUnaryOp<string mnemonic, Type type,
+                            list<OpTrait> traits = []> :
+      // Operand type same as result type.
+      SPV_UnaryOp<mnemonic, type, type,
+                  !listconcat(traits,
+                              [NoSideEffect, SameOperandsAndResultType])>;
+
+// -----
+
+def SPV_FAddOp : SPV_ArithmeticBinaryOp<"FAdd", SPV_Float, [Commutative]> {
+  let summary = "Floating-point addition of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+    Results are computed per component.
+
+    ### Custom assembly form
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fadd-op ::= ssa-id `=` `spv.FAdd` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.FAdd %0, %1 : f32
+    %5 = spv.FAdd %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FDivOp : SPV_ArithmeticBinaryOp<"FDiv", SPV_Float, []> {
+  let summary = "Floating-point division of Operand 1 divided by Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+    Results are computed per component. The resulting value is undefined
+    if Operand 2 is 0.
+
+    ### Custom assembly form
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fdiv-op ::= ssa-id `=` `spv.FDiv` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FDiv %0, %1 : f32
+    %5 = spv.FDiv %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FModOp : SPV_ArithmeticBinaryOp<"FMod", SPV_Float, []> {
+  let summary = [{
+    The floating-point remainder whose sign matches the sign of Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+    Results are computed per component. The resulting value is undefined
+    if Operand 2 is 0. Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 2.
+
+    ### Custom assembly form
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmod-op ::= ssa-id `=` `spv.FMod` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+    For example:
+
+    ```
+    %4 = spv.FMod %0, %1 : f32
+    %5 = spv.FMod %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FMulOp : SPV_ArithmeticBinaryOp<"FMul", SPV_Float, [Commutative]> {
+  let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+    Results are computed per component.
+
+    ### Custom assembly form
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmul-op ::= `spv.FMul` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FMul %0, %1 : f32
+    %5 = spv.FMul %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FNegateOp : SPV_ArithmeticUnaryOp<"FNegate", SPV_Float, []> {
+  let summary = "Floating-point subtract of Operand from zero.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    The type of Operand must be the same as Result Type.
+
+    Results are computed per component.
+
+    ### Custom assembly form
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fmul-op ::= `spv.FNegate` ssa-use `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %1 = spv.FNegate %0 : f32
+    %3 = spv.FNegate %2 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FRemOp : SPV_ArithmeticBinaryOp<"FRem", SPV_Float, []> {
+  let summary = [{
+    The floating-point remainder whose sign matches the sign of Operand 1.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+    Results are computed per component. The resulting value is undefined
+    if Operand 2 is 0. Otherwise, the result is the remainder r of Operand
+    1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the
+    sign of Operand 1.
+
+    ### Custom assembly form
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    frem-op ::= ssa-id `=` `spv.FRem` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FRem %0, %1 : f32
+    %5 = spv.FRem %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_FSubOp : SPV_ArithmeticBinaryOp<"FSub", SPV_Float, []> {
+  let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    The types of Operand 1 and Operand 2 both must be the same as Result
+    Type.
+
+    Results are computed per component.
+
+    ### Custom assembly form
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fsub-op ::= ssa-id `=` `spv.FSub` ssa-use, ssa-use
+                          `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %4 = spv.FSub %0, %1 : f32
+    %5 = spv.FSub %2, %3 : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd", SPV_Integer, [Commutative]> {
+  let summary = "Integer addition of Operand 1 and Operand 2.";
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+    The type of Operand 1 and Operand 2 must be a scalar or vector of
+    integer type. They must have the same number of components as Result
+    Type. They must have the same component width as Result Type.
+
+    The resulting value will equal the low-order N bits of the correct
+    result R, where N is the component width and R is computed with enough
+    precision to avoid overflow and underflow.
+
+    Results are computed per component.
+ + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + iadd-op ::= ssa-id `=` `spv.IAdd` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %4 = spv.IAdd %0, %1 : i32 + %5 = spv.IAdd %2, %3 : vector<4xi32> + + ``` + }]; + + let hasFolder = 1; +} + +// ----- + +def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul", SPV_Integer, [Commutative]> { + let summary = "Integer multiplication of Operand 1 and Operand 2."; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same number of components as Result + Type. They must have the same component width as Result Type. + + The resulting value will equal the low-order N bits of the correct + result R, where N is the component width and R is computed with enough + precision to avoid overflow and underflow. + + Results are computed per component. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + imul-op ::= ssa-id `=` `spv.IMul` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %4 = spv.IMul %0, %1 : i32 + %5 = spv.IMul %2, %3 : vector<4xi32> + + ``` + }]; + + let hasFolder = 1; +} + +// ----- + +def SPV_ISubOp : SPV_ArithmeticBinaryOp<"ISub", SPV_Integer, []> { + let summary = "Integer subtraction of Operand 2 from Operand 1."; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same number of components as Result + Type. They must have the same component width as Result Type. 
+ + The resulting value will equal the low-order N bits of the correct + result R, where N is the component width and R is computed with enough + precision to avoid overflow and underflow. + + Results are computed per component. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + isub-op ::= `spv.ISub` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %4 = spv.ISub %0, %1 : i32 + %5 = spv.ISub %2, %3 : vector<4xi32> + + ``` + }]; + + let hasFolder = 1; +} + +// ----- + +def SPV_SDivOp : SPV_ArithmeticBinaryOp<"SDiv", SPV_Integer, []> { + let summary = "Signed-integer division of Operand 1 divided by Operand 2."; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same number of components as Result + Type. They must have the same component width as Result Type. + + Results are computed per component. The resulting value is undefined + if Operand 2 is 0. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + sdiv-op ::= ssa-id `=` `spv.SDiv` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %4 = spv.SDiv %0, %1 : i32 + %5 = spv.SDiv %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_SModOp : SPV_ArithmeticBinaryOp<"SMod", SPV_Integer, []> { + let summary = [{ + Signed remainder operation for the remainder whose sign matches the sign + of Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same number of components as Result + Type. They must have the same component width as Result Type. + + Results are computed per component. 
The resulting value is undefined + if Operand 2 is 0. Otherwise, the result is the remainder r of Operand + 1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the + sign of Operand 2. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + smod-op ::= ssa-id `=` `spv.SMod` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.SMod %0, %1 : i32 + %5 = spv.SMod %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> { + let summary = [{ + Signed remainder operation for the remainder whose sign matches the sign + of Operand 1. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same number of components as Result + Type. They must have the same component width as Result Type. + + Results are computed per component. The resulting value is undefined + if Operand 2 is 0. Otherwise, the result is the remainder r of Operand + 1 divided by Operand 2 where if r ≠ 0, the sign of r is the same as the + sign of Operand 1. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + srem-op ::= ssa-id `=` `spv.SRem` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.SRem %0, %1 : i32 + %5 = spv.SRem %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> { + let summary = "Unsigned-integer division of Operand 1 divided by Operand 2."; + + let description = [{ + Result Type must be a scalar or vector of integer type, whose Signedness + operand is 0. + + The types of Operand 1 and Operand 2 both must be the same as Result + Type. + + Results are computed per component. 
The resulting value is undefined + if Operand 2 is 0. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + udiv-op ::= ssa-id `=` `spv.UDiv` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.UDiv %0, %1 : i32 + %5 = spv.UDiv %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer> { + let summary = "Unsigned modulo operation of Operand 1 modulo Operand 2."; + + let description = [{ + Result Type must be a scalar or vector of integer type, whose Signedness + operand is 0. + + The types of Operand 1 and Operand 2 both must be the same as Result + Type. + + Results are computed per component. The resulting value is undefined + if Operand 2 is 0. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + umod-op ::= ssa-id `=` `spv.UMod` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.UMod %0, %1 : i32 + %5 = spv.UMod %2, %3 : vector<4xi32> + + ``` + }]; +} + +#endif // SPIRV_ARITHMETIC_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td new file mode 100644 index 0000000000000000000000000000000000000000..c2ea100c12162798535887f5198ac69b929bd8e8 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td @@ -0,0 +1,552 @@ +//===-- SPIRVAtomicOps.td - MLIR SPIR-V Atomic Ops ---------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains atomic ops for the SPIR-V dialect. It corresponds to +// "3.32.18. 
Atomic Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_ATOMIC_OPS +#define SPIRV_ATOMIC_OPS + +class SPV_AtomicUpdateOp traits = []> : + SPV_Op { + let parser = [{ return ::parseAtomicUpdateOp(parser, result, false); }]; + let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; + let verifier = [{ return ::verifyAtomicUpdateOp(getOperation()); }]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_ScopeAttr:$memory_scope, + SPV_MemorySemanticsAttr:$semantics + ); + let results = (outs + SPV_Integer:$result + ); +} + +class SPV_AtomicUpdateWithValueOp traits = []> : + SPV_Op { + let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }]; + let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; + let verifier = [{ return ::verifyAtomicUpdateOp(getOperation()); }]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_ScopeAttr:$memory_scope, + SPV_MemorySemanticsAttr:$semantics, + SPV_Integer:$value + ); + let results = (outs + SPV_Integer:$result + ); +} + +// ----- + +def SPV_AtomicAndOp : SPV_AtomicUpdateWithValueOp<"AtomicAnd", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by the bitwise AND of Original Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + scope ::= `"CrossDevice"` | `"Device"` | `"Workgroup"` | ... + + memory-semantics ::= `"None"` | `"Acquire"` | "Release"` | ... 
+ + atomic-and-op ::= + `spv.AtomicAnd` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicAnd "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> { + let summary = "Deprecated (use OpAtomicCompareExchange)."; + + let description = [{ + Has the same semantics as OpAtomicCompareExchange. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-compare-exchange-weak-op ::= + `spv.AtomicCompareExchangeWeak` scope memory-semantics memory-semantics + ssa-use `,` ssa-use `,` ssa-use + `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicCompareExchangeWeak "Workgroup" "Acquire" "None" + %pointer, %value, %comparator + : !spv.ptr + ``` + }]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_ScopeAttr:$memory_scope, + SPV_MemorySemanticsAttr:$equal_semantics, + SPV_MemorySemanticsAttr:$unequal_semantics, + SPV_Integer:$value, + SPV_Integer:$comparator + ); + + let results = (outs + SPV_Integer:$result + ); +} + +// ----- + +def SPV_AtomicIAddOp : SPV_AtomicUpdateWithValueOp<"AtomicIAdd", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by integer addition of Original Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. 
+ + ### Custom assembly form + + ``` + atomic-iadd-op ::= + `spv.AtomicIAdd` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicIAdd "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicIDecrementOp : SPV_AtomicUpdateOp<"AtomicIDecrement", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value through integer subtraction of 1 from Original Value, + and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. The type of the value + pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-idecrement-op ::= + `spv.AtomicIDecrement` scope memory-semantics ssa-use + `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicIDecrement "Device" "None" %pointer : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicIIncrementOp : SPV_AtomicUpdateOp<"AtomicIIncrement", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value through integer addition of 1 to Original Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. The type of the value + pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. 
+ + ### Custom assembly form + + ``` + atomic-iincrement-op ::= + `spv.AtomicIIncrement` scope memory-semantics ssa-use + `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicIIncrement "Device" "None" %pointer : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicISubOp : SPV_AtomicUpdateWithValueOp<"AtomicISub", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by integer subtraction of Value from Original Value, + and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-isub-op ::= + `spv.AtomicISub` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicISub "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicOrOp : SPV_AtomicUpdateWithValueOp<"AtomicOr", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by the bitwise OR of Original Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. 
+ + ### Custom assembly form + + ``` + atomic-or-op ::= + `spv.AtomicOr` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicOr "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicSMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicSMax", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the largest signed integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-smax-op ::= + `spv.AtomicSMax` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicSMax "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicSMinOp : SPV_AtomicUpdateWithValueOp<"AtomicSMin", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the smallest signed integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. 
+ + ### Custom assembly form + + ``` + atomic-smin-op ::= + `spv.AtomicSMin` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicSMin "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the largest unsigned integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-umax-op ::= + `spv.AtomicUMax` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicUMax "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the smallest unsigned integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. 
+ + ### Custom assembly form + + ``` + atomic-umin-op ::= + `spv.AtomicUMin` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicUMin "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicXorOp : SPV_AtomicUpdateWithValueOp<"AtomicXor", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by the bitwise exclusive OR of Original Value and + Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-xor-op ::= + `spv.AtomicXor` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicXor "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +#endif // SPIRV_ATOMIC_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td new file mode 100644 index 0000000000000000000000000000000000000000..5751a32e1695ca3278f94267b4f2fc0dbd13190b --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -0,0 +1,1319 @@ +//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the base file for SPIR-V operation definition specification. +// This file defines the SPIR-V dialect, common SPIR-V types, and utilities +// for facilitating defining SPIR-V ops. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_BASE +#define SPIRV_BASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// SPIR-V dialect definitions +//===----------------------------------------------------------------------===// + +def SPV_Dialect : Dialect { + let name = "spv"; + + let description = [{ + The SPIR-V dialect in MLIR. + + SPIR-V is the Khronos Group's binary intermediate language for representing + graphical-shader stages and compute kernels for multiple Khronos APIs, + including OpenCL, OpenGL, and Vulkan. + See https://www.khronos.org/registry/spir-v for more details. + + This dialect aims to be a simple proxy for the SPIR-V binary format to + enable straightforward and lightweight conversion from/to the binary + format. Ops in this dialect should stay at the same semantic level and + try to be a mechanical mapping to the corresponding SPIR-V instructions; + but they may deviate representationally to allow using MLIR mechanisms. + As a convention, if such deviation happens, the op name follows "snake_case" + style; otherwise, the op name just follows the SPIR-V mnemonic (by removing + the leading `Op` prefix) to use "CamelCase" style. + }]; + + let cppNamespace = "spirv"; +} + +//===----------------------------------------------------------------------===// +// SPIR-V extension definitions +//===----------------------------------------------------------------------===// + +// Extensions known to the SPIR-V dialect. +// https://github.com/KhronosGroup/SPIRV-Registry has the full list. 
+def SPV_KHR_16bit_storage : StrEnumAttrCase<"SPV_KHR_16bit_storage">; +def SPV_KHR_8bit_storage : StrEnumAttrCase<"SPV_KHR_8bit_storage">; +def SPV_KHR_device_group : StrEnumAttrCase<"SPV_KHR_device_group">; +def SPV_KHR_float_controls : StrEnumAttrCase<"SPV_KHR_float_controls">; +def SPV_KHR_physical_storage_buffer : StrEnumAttrCase<"SPV_KHR_physical_storage_buffer">; +def SPV_KHR_multiview : StrEnumAttrCase<"SPV_KHR_multiview">; +def SPV_KHR_no_integer_wrap_decoration : StrEnumAttrCase<"SPV_KHR_no_integer_wrap_decoration">; +def SPV_KHR_post_depth_coverage : StrEnumAttrCase<"SPV_KHR_post_depth_coverage">; +def SPV_KHR_shader_atomic_counter_ops : StrEnumAttrCase<"SPV_KHR_shader_atomic_counter_ops">; +def SPV_KHR_shader_ballot : StrEnumAttrCase<"SPV_KHR_shader_ballot">; +def SPV_KHR_shader_draw_parameters : StrEnumAttrCase<"SPV_KHR_shader_draw_parameters">; +def SPV_KHR_storage_buffer_storage_class : StrEnumAttrCase<"SPV_KHR_storage_buffer_storage_class">; +def SPV_KHR_subgroup_vote : StrEnumAttrCase<"SPV_KHR_subgroup_vote">; +def SPV_KHR_variable_pointers : StrEnumAttrCase<"SPV_KHR_variable_pointers">; +def SPV_KHR_vulkan_memory_model : StrEnumAttrCase<"SPV_KHR_vulkan_memory_model">; + +def SPV_EXT_fragment_fully_covered : StrEnumAttrCase<"SPV_EXT_fragment_fully_covered">; +def SPV_EXT_fragment_invocation_density : StrEnumAttrCase<"SPV_EXT_fragment_invocation_density">; +def SPV_EXT_fragment_shader_interlock : StrEnumAttrCase<"SPV_EXT_fragment_shader_interlock">; +def SPV_EXT_physical_storage_buffer : StrEnumAttrCase<"SPV_EXT_physical_storage_buffer">; +def SPV_EXT_shader_stencil_export : StrEnumAttrCase<"SPV_EXT_shader_stencil_export">; + +def SPV_AMD_shader_explicit_vertex_parameter : StrEnumAttrCase<"SPV_AMD_shader_explicit_vertex_parameter">; + +def SPV_GOOGLE_user_type : StrEnumAttrCase<"SPV_GOOGLE_user_type">; + +def SPV_NV_compute_shader_derivatives : StrEnumAttrCase<"SPV_NV_compute_shader_derivatives">; +def SPV_NV_fragment_shader_barycentric : 
StrEnumAttrCase<"SPV_NV_fragment_shader_barycentric">; +def SPV_NV_geometry_shader_passthrough : StrEnumAttrCase<"SPV_NV_geometry_shader_passthrough">; +def SPV_NV_mesh_shader : StrEnumAttrCase<"SPV_NV_mesh_shader">; +def SPV_NV_ray_tracing : StrEnumAttrCase<"SPV_NV_ray_tracing">; +def SPV_NV_sample_mask_override_coverage : StrEnumAttrCase<"SPV_NV_sample_mask_override_coverage">; +def SPV_NV_shader_sm_builtins : StrEnumAttrCase<"SPV_NV_shader_sm_builtins">; +def SPV_NV_shading_rate : StrEnumAttrCase<"SPV_NV_shading_rate">; +def SPV_NV_stereo_view_rendering : StrEnumAttrCase<"SPV_NV_stereo_view_rendering">; +def SPV_NV_viewport_array2 : StrEnumAttrCase<"SPV_NV_viewport_array2">; + +def SPV_NVX_multiview_per_view_attributes : StrEnumAttrCase<"SPV_NVX_multiview_per_view_attributes">; + +def SPV_ExtensionAttr : + StrEnumAttr<"Extension", "supported SPIR-V extensions", [ + SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group, + SPV_KHR_float_controls, SPV_KHR_physical_storage_buffer, SPV_KHR_multiview, + SPV_KHR_no_integer_wrap_decoration, SPV_KHR_post_depth_coverage, + SPV_KHR_shader_atomic_counter_ops, SPV_KHR_shader_ballot, + SPV_KHR_shader_draw_parameters, SPV_KHR_storage_buffer_storage_class, + SPV_KHR_subgroup_vote, SPV_KHR_variable_pointers, + SPV_KHR_vulkan_memory_model, SPV_EXT_fragment_fully_covered, + SPV_EXT_fragment_invocation_density, SPV_EXT_fragment_shader_interlock, + SPV_EXT_physical_storage_buffer, SPV_EXT_shader_stencil_export, + SPV_AMD_shader_explicit_vertex_parameter, SPV_GOOGLE_user_type, + SPV_NV_compute_shader_derivatives, SPV_NV_fragment_shader_barycentric, + SPV_NV_geometry_shader_passthrough, SPV_NV_mesh_shader, SPV_NV_ray_tracing, + SPV_NV_sample_mask_override_coverage, SPV_NV_shader_sm_builtins, + SPV_NV_shading_rate, SPV_NV_stereo_view_rendering, + SPV_NV_viewport_array2, SPV_NVX_multiview_per_view_attributes, + ]> { + let cppNamespace = "::mlir::spirv"; +} + 
+//===----------------------------------------------------------------------===// +// SPIR-V enum definitions +//===----------------------------------------------------------------------===// + +// Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY! + +def SPV_C_Matrix : I32EnumAttrCase<"Matrix", 0>; +def SPV_C_Shader : I32EnumAttrCase<"Shader", 1>; +def SPV_C_Geometry : I32EnumAttrCase<"Geometry", 2>; +def SPV_C_Tessellation : I32EnumAttrCase<"Tessellation", 3>; +def SPV_C_Addresses : I32EnumAttrCase<"Addresses", 4>; +def SPV_C_Linkage : I32EnumAttrCase<"Linkage", 5>; +def SPV_C_Kernel : I32EnumAttrCase<"Kernel", 6>; +def SPV_C_Vector16 : I32EnumAttrCase<"Vector16", 7>; +def SPV_C_Float16Buffer : I32EnumAttrCase<"Float16Buffer", 8>; +def SPV_C_Float16 : I32EnumAttrCase<"Float16", 9>; +def SPV_C_Float64 : I32EnumAttrCase<"Float64", 10>; +def SPV_C_Int64 : I32EnumAttrCase<"Int64", 11>; +def SPV_C_Int64Atomics : I32EnumAttrCase<"Int64Atomics", 12>; +def SPV_C_ImageBasic : I32EnumAttrCase<"ImageBasic", 13>; +def SPV_C_ImageReadWrite : I32EnumAttrCase<"ImageReadWrite", 14>; +def SPV_C_ImageMipmap : I32EnumAttrCase<"ImageMipmap", 15>; +def SPV_C_Pipes : I32EnumAttrCase<"Pipes", 17>; +def SPV_C_Groups : I32EnumAttrCase<"Groups", 18>; +def SPV_C_DeviceEnqueue : I32EnumAttrCase<"DeviceEnqueue", 19>; +def SPV_C_LiteralSampler : I32EnumAttrCase<"LiteralSampler", 20>; +def SPV_C_AtomicStorage : I32EnumAttrCase<"AtomicStorage", 21>; +def SPV_C_Int16 : I32EnumAttrCase<"Int16", 22>; +def SPV_C_TessellationPointSize : I32EnumAttrCase<"TessellationPointSize", 23>; +def SPV_C_GeometryPointSize : I32EnumAttrCase<"GeometryPointSize", 24>; +def SPV_C_ImageGatherExtended : I32EnumAttrCase<"ImageGatherExtended", 25>; +def SPV_C_StorageImageMultisample : I32EnumAttrCase<"StorageImageMultisample", 27>; +def SPV_C_UniformBufferArrayDynamicIndexing : I32EnumAttrCase<"UniformBufferArrayDynamicIndexing", 28>; +def SPV_C_SampledImageArrayDynamicIndexing : 
I32EnumAttrCase<"SampledImageArrayDynamicIndexing", 29>; +def SPV_C_StorageBufferArrayDynamicIndexing : I32EnumAttrCase<"StorageBufferArrayDynamicIndexing", 30>; +def SPV_C_StorageImageArrayDynamicIndexing : I32EnumAttrCase<"StorageImageArrayDynamicIndexing", 31>; +def SPV_C_ClipDistance : I32EnumAttrCase<"ClipDistance", 32>; +def SPV_C_CullDistance : I32EnumAttrCase<"CullDistance", 33>; +def SPV_C_ImageCubeArray : I32EnumAttrCase<"ImageCubeArray", 34>; +def SPV_C_SampleRateShading : I32EnumAttrCase<"SampleRateShading", 35>; +def SPV_C_ImageRect : I32EnumAttrCase<"ImageRect", 36>; +def SPV_C_SampledRect : I32EnumAttrCase<"SampledRect", 37>; +def SPV_C_GenericPointer : I32EnumAttrCase<"GenericPointer", 38>; +def SPV_C_Int8 : I32EnumAttrCase<"Int8", 39>; +def SPV_C_InputAttachment : I32EnumAttrCase<"InputAttachment", 40>; +def SPV_C_SparseResidency : I32EnumAttrCase<"SparseResidency", 41>; +def SPV_C_MinLod : I32EnumAttrCase<"MinLod", 42>; +def SPV_C_Sampled1D : I32EnumAttrCase<"Sampled1D", 43>; +def SPV_C_Image1D : I32EnumAttrCase<"Image1D", 44>; +def SPV_C_SampledCubeArray : I32EnumAttrCase<"SampledCubeArray", 45>; +def SPV_C_SampledBuffer : I32EnumAttrCase<"SampledBuffer", 46>; +def SPV_C_ImageBuffer : I32EnumAttrCase<"ImageBuffer", 47>; +def SPV_C_ImageMSArray : I32EnumAttrCase<"ImageMSArray", 48>; +def SPV_C_StorageImageExtendedFormats : I32EnumAttrCase<"StorageImageExtendedFormats", 49>; +def SPV_C_ImageQuery : I32EnumAttrCase<"ImageQuery", 50>; +def SPV_C_DerivativeControl : I32EnumAttrCase<"DerivativeControl", 51>; +def SPV_C_InterpolationFunction : I32EnumAttrCase<"InterpolationFunction", 52>; +def SPV_C_TransformFeedback : I32EnumAttrCase<"TransformFeedback", 53>; +def SPV_C_GeometryStreams : I32EnumAttrCase<"GeometryStreams", 54>; +def SPV_C_StorageImageReadWithoutFormat : I32EnumAttrCase<"StorageImageReadWithoutFormat", 55>; +def SPV_C_StorageImageWriteWithoutFormat : I32EnumAttrCase<"StorageImageWriteWithoutFormat", 56>; +def SPV_C_MultiViewport : 
I32EnumAttrCase<"MultiViewport", 57>; +def SPV_C_SubgroupDispatch : I32EnumAttrCase<"SubgroupDispatch", 58>; +def SPV_C_NamedBarrier : I32EnumAttrCase<"NamedBarrier", 59>; +def SPV_C_PipeStorage : I32EnumAttrCase<"PipeStorage", 60>; +def SPV_C_GroupNonUniform : I32EnumAttrCase<"GroupNonUniform", 61>; +def SPV_C_GroupNonUniformVote : I32EnumAttrCase<"GroupNonUniformVote", 62>; +def SPV_C_GroupNonUniformArithmetic : I32EnumAttrCase<"GroupNonUniformArithmetic", 63>; +def SPV_C_GroupNonUniformBallot : I32EnumAttrCase<"GroupNonUniformBallot", 64>; +def SPV_C_GroupNonUniformShuffle : I32EnumAttrCase<"GroupNonUniformShuffle", 65>; +def SPV_C_GroupNonUniformShuffleRelative : I32EnumAttrCase<"GroupNonUniformShuffleRelative", 66>; +def SPV_C_GroupNonUniformClustered : I32EnumAttrCase<"GroupNonUniformClustered", 67>; +def SPV_C_GroupNonUniformQuad : I32EnumAttrCase<"GroupNonUniformQuad", 68>; +def SPV_C_ShaderLayer : I32EnumAttrCase<"ShaderLayer", 69>; +def SPV_C_ShaderViewportIndex : I32EnumAttrCase<"ShaderViewportIndex", 70>; +def SPV_C_SubgroupBallotKHR : I32EnumAttrCase<"SubgroupBallotKHR", 4423>; +def SPV_C_DrawParameters : I32EnumAttrCase<"DrawParameters", 4427>; +def SPV_C_SubgroupVoteKHR : I32EnumAttrCase<"SubgroupVoteKHR", 4431>; +def SPV_C_StorageBuffer16BitAccess : I32EnumAttrCase<"StorageBuffer16BitAccess", 4433>; +def SPV_C_StorageUniform16 : I32EnumAttrCase<"StorageUniform16", 4434>; +def SPV_C_StoragePushConstant16 : I32EnumAttrCase<"StoragePushConstant16", 4435>; +def SPV_C_StorageInputOutput16 : I32EnumAttrCase<"StorageInputOutput16", 4436>; +def SPV_C_DeviceGroup : I32EnumAttrCase<"DeviceGroup", 4437>; +def SPV_C_MultiView : I32EnumAttrCase<"MultiView", 4439>; +def SPV_C_VariablePointersStorageBuffer : I32EnumAttrCase<"VariablePointersStorageBuffer", 4441>; +def SPV_C_VariablePointers : I32EnumAttrCase<"VariablePointers", 4442>; +def SPV_C_AtomicStorageOps : I32EnumAttrCase<"AtomicStorageOps", 4445>; +def SPV_C_SampleMaskPostDepthCoverage : 
I32EnumAttrCase<"SampleMaskPostDepthCoverage", 4447>; +def SPV_C_StorageBuffer8BitAccess : I32EnumAttrCase<"StorageBuffer8BitAccess", 4448>; +def SPV_C_UniformAndStorageBuffer8BitAccess : I32EnumAttrCase<"UniformAndStorageBuffer8BitAccess", 4449>; +def SPV_C_StoragePushConstant8 : I32EnumAttrCase<"StoragePushConstant8", 4450>; +def SPV_C_DenormPreserve : I32EnumAttrCase<"DenormPreserve", 4464>; +def SPV_C_DenormFlushToZero : I32EnumAttrCase<"DenormFlushToZero", 4465>; +def SPV_C_SignedZeroInfNanPreserve : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4466>; +def SPV_C_RoundingModeRTE : I32EnumAttrCase<"RoundingModeRTE", 4467>; +def SPV_C_RoundingModeRTZ : I32EnumAttrCase<"RoundingModeRTZ", 4468>; +def SPV_C_Float16ImageAMD : I32EnumAttrCase<"Float16ImageAMD", 5008>; +def SPV_C_ImageGatherBiasLodAMD : I32EnumAttrCase<"ImageGatherBiasLodAMD", 5009>; +def SPV_C_FragmentMaskAMD : I32EnumAttrCase<"FragmentMaskAMD", 5010>; +def SPV_C_StencilExportEXT : I32EnumAttrCase<"StencilExportEXT", 5013>; +def SPV_C_ImageReadWriteLodAMD : I32EnumAttrCase<"ImageReadWriteLodAMD", 5015>; +def SPV_C_ShaderClockKHR : I32EnumAttrCase<"ShaderClockKHR", 5055>; +def SPV_C_SampleMaskOverrideCoverageNV : I32EnumAttrCase<"SampleMaskOverrideCoverageNV", 5249>; +def SPV_C_GeometryShaderPassthroughNV : I32EnumAttrCase<"GeometryShaderPassthroughNV", 5251>; +def SPV_C_ShaderViewportIndexLayerEXT : I32EnumAttrCase<"ShaderViewportIndexLayerEXT", 5254>; +def SPV_C_ShaderViewportMaskNV : I32EnumAttrCase<"ShaderViewportMaskNV", 5255>; +def SPV_C_ShaderStereoViewNV : I32EnumAttrCase<"ShaderStereoViewNV", 5259>; +def SPV_C_PerViewAttributesNV : I32EnumAttrCase<"PerViewAttributesNV", 5260>; +def SPV_C_FragmentFullyCoveredEXT : I32EnumAttrCase<"FragmentFullyCoveredEXT", 5265>; +def SPV_C_MeshShadingNV : I32EnumAttrCase<"MeshShadingNV", 5266>; +def SPV_C_ImageFootprintNV : I32EnumAttrCase<"ImageFootprintNV", 5282>; +def SPV_C_FragmentBarycentricNV : I32EnumAttrCase<"FragmentBarycentricNV", 5284>; +def 
SPV_C_ComputeDerivativeGroupQuadsNV : I32EnumAttrCase<"ComputeDerivativeGroupQuadsNV", 5288>; +def SPV_C_FragmentDensityEXT : I32EnumAttrCase<"FragmentDensityEXT", 5291>; +def SPV_C_GroupNonUniformPartitionedNV : I32EnumAttrCase<"GroupNonUniformPartitionedNV", 5297>; +def SPV_C_ShaderNonUniform : I32EnumAttrCase<"ShaderNonUniform", 5301>; +def SPV_C_RuntimeDescriptorArray : I32EnumAttrCase<"RuntimeDescriptorArray", 5302>; +def SPV_C_InputAttachmentArrayDynamicIndexing : I32EnumAttrCase<"InputAttachmentArrayDynamicIndexing", 5303>; +def SPV_C_UniformTexelBufferArrayDynamicIndexing : I32EnumAttrCase<"UniformTexelBufferArrayDynamicIndexing", 5304>; +def SPV_C_StorageTexelBufferArrayDynamicIndexing : I32EnumAttrCase<"StorageTexelBufferArrayDynamicIndexing", 5305>; +def SPV_C_UniformBufferArrayNonUniformIndexing : I32EnumAttrCase<"UniformBufferArrayNonUniformIndexing", 5306>; +def SPV_C_SampledImageArrayNonUniformIndexing : I32EnumAttrCase<"SampledImageArrayNonUniformIndexing", 5307>; +def SPV_C_StorageBufferArrayNonUniformIndexing : I32EnumAttrCase<"StorageBufferArrayNonUniformIndexing", 5308>; +def SPV_C_StorageImageArrayNonUniformIndexing : I32EnumAttrCase<"StorageImageArrayNonUniformIndexing", 5309>; +def SPV_C_InputAttachmentArrayNonUniformIndexing : I32EnumAttrCase<"InputAttachmentArrayNonUniformIndexing", 5310>; +def SPV_C_UniformTexelBufferArrayNonUniformIndexing : I32EnumAttrCase<"UniformTexelBufferArrayNonUniformIndexing", 5311>; +def SPV_C_StorageTexelBufferArrayNonUniformIndexing : I32EnumAttrCase<"StorageTexelBufferArrayNonUniformIndexing", 5312>; +def SPV_C_RayTracingNV : I32EnumAttrCase<"RayTracingNV", 5340>; +def SPV_C_VulkanMemoryModel : I32EnumAttrCase<"VulkanMemoryModel", 5345>; +def SPV_C_VulkanMemoryModelDeviceScope : I32EnumAttrCase<"VulkanMemoryModelDeviceScope", 5346>; +def SPV_C_PhysicalStorageBufferAddresses : I32EnumAttrCase<"PhysicalStorageBufferAddresses", 5347>; +def SPV_C_ComputeDerivativeGroupLinearNV : 
I32EnumAttrCase<"ComputeDerivativeGroupLinearNV", 5350>; +def SPV_C_CooperativeMatrixNV : I32EnumAttrCase<"CooperativeMatrixNV", 5357>; +def SPV_C_FragmentShaderSampleInterlockEXT : I32EnumAttrCase<"FragmentShaderSampleInterlockEXT", 5363>; +def SPV_C_FragmentShaderShadingRateInterlockEXT : I32EnumAttrCase<"FragmentShaderShadingRateInterlockEXT", 5372>; +def SPV_C_ShaderSMBuiltinsNV : I32EnumAttrCase<"ShaderSMBuiltinsNV", 5373>; +def SPV_C_FragmentShaderPixelInterlockEXT : I32EnumAttrCase<"FragmentShaderPixelInterlockEXT", 5378>; +def SPV_C_DemoteToHelperInvocationEXT : I32EnumAttrCase<"DemoteToHelperInvocationEXT", 5379>; +def SPV_C_SubgroupShuffleINTEL : I32EnumAttrCase<"SubgroupShuffleINTEL", 5568>; +def SPV_C_SubgroupBufferBlockIOINTEL : I32EnumAttrCase<"SubgroupBufferBlockIOINTEL", 5569>; +def SPV_C_SubgroupImageBlockIOINTEL : I32EnumAttrCase<"SubgroupImageBlockIOINTEL", 5570>; +def SPV_C_SubgroupImageMediaBlockIOINTEL : I32EnumAttrCase<"SubgroupImageMediaBlockIOINTEL", 5579>; +def SPV_C_IntegerFunctions2INTEL : I32EnumAttrCase<"IntegerFunctions2INTEL", 5584>; +def SPV_C_SubgroupAvcMotionEstimationINTEL : I32EnumAttrCase<"SubgroupAvcMotionEstimationINTEL", 5696>; +def SPV_C_SubgroupAvcMotionEstimationIntraINTEL : I32EnumAttrCase<"SubgroupAvcMotionEstimationIntraINTEL", 5697>; +def SPV_C_SubgroupAvcMotionEstimationChromaINTEL : I32EnumAttrCase<"SubgroupAvcMotionEstimationChromaINTEL", 5698>; + +def SPV_CapabilityAttr : + I32EnumAttr<"Capability", "valid SPIR-V Capability", [ + SPV_C_Matrix, SPV_C_Shader, SPV_C_Geometry, SPV_C_Tessellation, + SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Vector16, + SPV_C_Float16Buffer, SPV_C_Float16, SPV_C_Float64, SPV_C_Int64, + SPV_C_Int64Atomics, SPV_C_ImageBasic, SPV_C_ImageReadWrite, SPV_C_ImageMipmap, + SPV_C_Pipes, SPV_C_Groups, SPV_C_DeviceEnqueue, SPV_C_LiteralSampler, + SPV_C_AtomicStorage, SPV_C_Int16, SPV_C_TessellationPointSize, + SPV_C_GeometryPointSize, SPV_C_ImageGatherExtended, + 
SPV_C_StorageImageMultisample, SPV_C_UniformBufferArrayDynamicIndexing, + SPV_C_SampledImageArrayDynamicIndexing, + SPV_C_StorageBufferArrayDynamicIndexing, + SPV_C_StorageImageArrayDynamicIndexing, SPV_C_ClipDistance, SPV_C_CullDistance, + SPV_C_ImageCubeArray, SPV_C_SampleRateShading, SPV_C_ImageRect, + SPV_C_SampledRect, SPV_C_GenericPointer, SPV_C_Int8, SPV_C_InputAttachment, + SPV_C_SparseResidency, SPV_C_MinLod, SPV_C_Sampled1D, SPV_C_Image1D, + SPV_C_SampledCubeArray, SPV_C_SampledBuffer, SPV_C_ImageBuffer, + SPV_C_ImageMSArray, SPV_C_StorageImageExtendedFormats, SPV_C_ImageQuery, + SPV_C_DerivativeControl, SPV_C_InterpolationFunction, SPV_C_TransformFeedback, + SPV_C_GeometryStreams, SPV_C_StorageImageReadWithoutFormat, + SPV_C_StorageImageWriteWithoutFormat, SPV_C_MultiViewport, + SPV_C_SubgroupDispatch, SPV_C_NamedBarrier, SPV_C_PipeStorage, + SPV_C_GroupNonUniform, SPV_C_GroupNonUniformVote, + SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, + SPV_C_GroupNonUniformShuffle, SPV_C_GroupNonUniformShuffleRelative, + SPV_C_GroupNonUniformClustered, SPV_C_GroupNonUniformQuad, SPV_C_ShaderLayer, + SPV_C_ShaderViewportIndex, SPV_C_SubgroupBallotKHR, SPV_C_DrawParameters, + SPV_C_SubgroupVoteKHR, SPV_C_StorageBuffer16BitAccess, SPV_C_StorageUniform16, + SPV_C_StoragePushConstant16, SPV_C_StorageInputOutput16, SPV_C_DeviceGroup, + SPV_C_MultiView, SPV_C_VariablePointersStorageBuffer, SPV_C_VariablePointers, + SPV_C_AtomicStorageOps, SPV_C_SampleMaskPostDepthCoverage, + SPV_C_StorageBuffer8BitAccess, SPV_C_UniformAndStorageBuffer8BitAccess, + SPV_C_StoragePushConstant8, SPV_C_DenormPreserve, SPV_C_DenormFlushToZero, + SPV_C_SignedZeroInfNanPreserve, SPV_C_RoundingModeRTE, SPV_C_RoundingModeRTZ, + SPV_C_Float16ImageAMD, SPV_C_ImageGatherBiasLodAMD, SPV_C_FragmentMaskAMD, + SPV_C_StencilExportEXT, SPV_C_ImageReadWriteLodAMD, SPV_C_ShaderClockKHR, + SPV_C_SampleMaskOverrideCoverageNV, SPV_C_GeometryShaderPassthroughNV, + SPV_C_ShaderViewportIndexLayerEXT, 
SPV_C_ShaderViewportMaskNV, + SPV_C_ShaderStereoViewNV, SPV_C_PerViewAttributesNV, + SPV_C_FragmentFullyCoveredEXT, SPV_C_MeshShadingNV, SPV_C_ImageFootprintNV, + SPV_C_FragmentBarycentricNV, SPV_C_ComputeDerivativeGroupQuadsNV, + SPV_C_FragmentDensityEXT, SPV_C_GroupNonUniformPartitionedNV, + SPV_C_ShaderNonUniform, SPV_C_RuntimeDescriptorArray, + SPV_C_InputAttachmentArrayDynamicIndexing, + SPV_C_UniformTexelBufferArrayDynamicIndexing, + SPV_C_StorageTexelBufferArrayDynamicIndexing, + SPV_C_UniformBufferArrayNonUniformIndexing, + SPV_C_SampledImageArrayNonUniformIndexing, + SPV_C_StorageBufferArrayNonUniformIndexing, + SPV_C_StorageImageArrayNonUniformIndexing, + SPV_C_InputAttachmentArrayNonUniformIndexing, + SPV_C_UniformTexelBufferArrayNonUniformIndexing, + SPV_C_StorageTexelBufferArrayNonUniformIndexing, SPV_C_RayTracingNV, + SPV_C_VulkanMemoryModel, SPV_C_VulkanMemoryModelDeviceScope, + SPV_C_PhysicalStorageBufferAddresses, SPV_C_ComputeDerivativeGroupLinearNV, + SPV_C_CooperativeMatrixNV, SPV_C_FragmentShaderSampleInterlockEXT, + SPV_C_FragmentShaderShadingRateInterlockEXT, SPV_C_ShaderSMBuiltinsNV, + SPV_C_FragmentShaderPixelInterlockEXT, SPV_C_DemoteToHelperInvocationEXT, + SPV_C_SubgroupShuffleINTEL, SPV_C_SubgroupBufferBlockIOINTEL, + SPV_C_SubgroupImageBlockIOINTEL, SPV_C_SubgroupImageMediaBlockIOINTEL, + SPV_C_IntegerFunctions2INTEL, SPV_C_SubgroupAvcMotionEstimationINTEL, + SPV_C_SubgroupAvcMotionEstimationIntraINTEL, + SPV_C_SubgroupAvcMotionEstimationChromaINTEL + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>; +def SPV_AM_Physical32 : I32EnumAttrCase<"Physical32", 1>; +def SPV_AM_Physical64 : I32EnumAttrCase<"Physical64", 2>; +def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348>; + +def SPV_AddressingModelAttr : + I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [ + SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64, + 
SPV_AM_PhysicalStorageBuffer64 + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_BI_Position : I32EnumAttrCase<"Position", 0>; +def SPV_BI_PointSize : I32EnumAttrCase<"PointSize", 1>; +def SPV_BI_ClipDistance : I32EnumAttrCase<"ClipDistance", 3>; +def SPV_BI_CullDistance : I32EnumAttrCase<"CullDistance", 4>; +def SPV_BI_VertexId : I32EnumAttrCase<"VertexId", 5>; +def SPV_BI_InstanceId : I32EnumAttrCase<"InstanceId", 6>; +def SPV_BI_PrimitiveId : I32EnumAttrCase<"PrimitiveId", 7>; +def SPV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8>; +def SPV_BI_Layer : I32EnumAttrCase<"Layer", 9>; +def SPV_BI_ViewportIndex : I32EnumAttrCase<"ViewportIndex", 10>; +def SPV_BI_TessLevelOuter : I32EnumAttrCase<"TessLevelOuter", 11>; +def SPV_BI_TessLevelInner : I32EnumAttrCase<"TessLevelInner", 12>; +def SPV_BI_TessCoord : I32EnumAttrCase<"TessCoord", 13>; +def SPV_BI_PatchVertices : I32EnumAttrCase<"PatchVertices", 14>; +def SPV_BI_FragCoord : I32EnumAttrCase<"FragCoord", 15>; +def SPV_BI_PointCoord : I32EnumAttrCase<"PointCoord", 16>; +def SPV_BI_FrontFacing : I32EnumAttrCase<"FrontFacing", 17>; +def SPV_BI_SampleId : I32EnumAttrCase<"SampleId", 18>; +def SPV_BI_SamplePosition : I32EnumAttrCase<"SamplePosition", 19>; +def SPV_BI_SampleMask : I32EnumAttrCase<"SampleMask", 20>; +def SPV_BI_FragDepth : I32EnumAttrCase<"FragDepth", 22>; +def SPV_BI_HelperInvocation : I32EnumAttrCase<"HelperInvocation", 23>; +def SPV_BI_NumWorkgroups : I32EnumAttrCase<"NumWorkgroups", 24>; +def SPV_BI_WorkgroupSize : I32EnumAttrCase<"WorkgroupSize", 25>; +def SPV_BI_WorkgroupId : I32EnumAttrCase<"WorkgroupId", 26>; +def SPV_BI_LocalInvocationId : I32EnumAttrCase<"LocalInvocationId", 27>; +def SPV_BI_GlobalInvocationId : I32EnumAttrCase<"GlobalInvocationId", 28>; +def SPV_BI_LocalInvocationIndex : I32EnumAttrCase<"LocalInvocationIndex", 29>; +def SPV_BI_WorkDim : I32EnumAttrCase<"WorkDim", 30>; +def SPV_BI_GlobalSize : I32EnumAttrCase<"GlobalSize", 31>; +def SPV_BI_EnqueuedWorkgroupSize : 
I32EnumAttrCase<"EnqueuedWorkgroupSize", 32>; +def SPV_BI_GlobalOffset : I32EnumAttrCase<"GlobalOffset", 33>; +def SPV_BI_GlobalLinearId : I32EnumAttrCase<"GlobalLinearId", 34>; +def SPV_BI_SubgroupSize : I32EnumAttrCase<"SubgroupSize", 36>; +def SPV_BI_SubgroupMaxSize : I32EnumAttrCase<"SubgroupMaxSize", 37>; +def SPV_BI_NumSubgroups : I32EnumAttrCase<"NumSubgroups", 38>; +def SPV_BI_NumEnqueuedSubgroups : I32EnumAttrCase<"NumEnqueuedSubgroups", 39>; +def SPV_BI_SubgroupId : I32EnumAttrCase<"SubgroupId", 40>; +def SPV_BI_SubgroupLocalInvocationId : I32EnumAttrCase<"SubgroupLocalInvocationId", 41>; +def SPV_BI_VertexIndex : I32EnumAttrCase<"VertexIndex", 42>; +def SPV_BI_InstanceIndex : I32EnumAttrCase<"InstanceIndex", 43>; +def SPV_BI_SubgroupEqMask : I32EnumAttrCase<"SubgroupEqMask", 4416>; +def SPV_BI_SubgroupGeMask : I32EnumAttrCase<"SubgroupGeMask", 4417>; +def SPV_BI_SubgroupGtMask : I32EnumAttrCase<"SubgroupGtMask", 4418>; +def SPV_BI_SubgroupLeMask : I32EnumAttrCase<"SubgroupLeMask", 4419>; +def SPV_BI_SubgroupLtMask : I32EnumAttrCase<"SubgroupLtMask", 4420>; +def SPV_BI_BaseVertex : I32EnumAttrCase<"BaseVertex", 4424>; +def SPV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425>; +def SPV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426>; +def SPV_BI_DeviceIndex : I32EnumAttrCase<"DeviceIndex", 4438>; +def SPV_BI_ViewIndex : I32EnumAttrCase<"ViewIndex", 4440>; +def SPV_BI_BaryCoordNoPerspAMD : I32EnumAttrCase<"BaryCoordNoPerspAMD", 4992>; +def SPV_BI_BaryCoordNoPerspCentroidAMD : I32EnumAttrCase<"BaryCoordNoPerspCentroidAMD", 4993>; +def SPV_BI_BaryCoordNoPerspSampleAMD : I32EnumAttrCase<"BaryCoordNoPerspSampleAMD", 4994>; +def SPV_BI_BaryCoordSmoothAMD : I32EnumAttrCase<"BaryCoordSmoothAMD", 4995>; +def SPV_BI_BaryCoordSmoothCentroidAMD : I32EnumAttrCase<"BaryCoordSmoothCentroidAMD", 4996>; +def SPV_BI_BaryCoordSmoothSampleAMD : I32EnumAttrCase<"BaryCoordSmoothSampleAMD", 4997>; +def SPV_BI_BaryCoordPullModelAMD : 
I32EnumAttrCase<"BaryCoordPullModelAMD", 4998>; +def SPV_BI_FragStencilRefEXT : I32EnumAttrCase<"FragStencilRefEXT", 5014>; +def SPV_BI_ViewportMaskNV : I32EnumAttrCase<"ViewportMaskNV", 5253>; +def SPV_BI_SecondaryPositionNV : I32EnumAttrCase<"SecondaryPositionNV", 5257>; +def SPV_BI_SecondaryViewportMaskNV : I32EnumAttrCase<"SecondaryViewportMaskNV", 5258>; +def SPV_BI_PositionPerViewNV : I32EnumAttrCase<"PositionPerViewNV", 5261>; +def SPV_BI_ViewportMaskPerViewNV : I32EnumAttrCase<"ViewportMaskPerViewNV", 5262>; +def SPV_BI_FullyCoveredEXT : I32EnumAttrCase<"FullyCoveredEXT", 5264>; +def SPV_BI_TaskCountNV : I32EnumAttrCase<"TaskCountNV", 5274>; +def SPV_BI_PrimitiveCountNV : I32EnumAttrCase<"PrimitiveCountNV", 5275>; +def SPV_BI_PrimitiveIndicesNV : I32EnumAttrCase<"PrimitiveIndicesNV", 5276>; +def SPV_BI_ClipDistancePerViewNV : I32EnumAttrCase<"ClipDistancePerViewNV", 5277>; +def SPV_BI_CullDistancePerViewNV : I32EnumAttrCase<"CullDistancePerViewNV", 5278>; +def SPV_BI_LayerPerViewNV : I32EnumAttrCase<"LayerPerViewNV", 5279>; +def SPV_BI_MeshViewCountNV : I32EnumAttrCase<"MeshViewCountNV", 5280>; +def SPV_BI_MeshViewIndicesNV : I32EnumAttrCase<"MeshViewIndicesNV", 5281>; +def SPV_BI_BaryCoordNV : I32EnumAttrCase<"BaryCoordNV", 5286>; +def SPV_BI_BaryCoordNoPerspNV : I32EnumAttrCase<"BaryCoordNoPerspNV", 5287>; +def SPV_BI_FragSizeEXT : I32EnumAttrCase<"FragSizeEXT", 5292>; +def SPV_BI_FragInvocationCountEXT : I32EnumAttrCase<"FragInvocationCountEXT", 5293>; +def SPV_BI_LaunchIdNV : I32EnumAttrCase<"LaunchIdNV", 5319>; +def SPV_BI_LaunchSizeNV : I32EnumAttrCase<"LaunchSizeNV", 5320>; +def SPV_BI_WorldRayOriginNV : I32EnumAttrCase<"WorldRayOriginNV", 5321>; +def SPV_BI_WorldRayDirectionNV : I32EnumAttrCase<"WorldRayDirectionNV", 5322>; +def SPV_BI_ObjectRayOriginNV : I32EnumAttrCase<"ObjectRayOriginNV", 5323>; +def SPV_BI_ObjectRayDirectionNV : I32EnumAttrCase<"ObjectRayDirectionNV", 5324>; +def SPV_BI_RayTminNV : I32EnumAttrCase<"RayTminNV", 5325>; +def 
SPV_BI_RayTmaxNV : I32EnumAttrCase<"RayTmaxNV", 5326>; +def SPV_BI_InstanceCustomIndexNV : I32EnumAttrCase<"InstanceCustomIndexNV", 5327>; +def SPV_BI_ObjectToWorldNV : I32EnumAttrCase<"ObjectToWorldNV", 5330>; +def SPV_BI_WorldToObjectNV : I32EnumAttrCase<"WorldToObjectNV", 5331>; +def SPV_BI_HitTNV : I32EnumAttrCase<"HitTNV", 5332>; +def SPV_BI_HitKindNV : I32EnumAttrCase<"HitKindNV", 5333>; +def SPV_BI_IncomingRayFlagsNV : I32EnumAttrCase<"IncomingRayFlagsNV", 5351>; +def SPV_BI_WarpsPerSMNV : I32EnumAttrCase<"WarpsPerSMNV", 5374>; +def SPV_BI_SMCountNV : I32EnumAttrCase<"SMCountNV", 5375>; +def SPV_BI_WarpIDNV : I32EnumAttrCase<"WarpIDNV", 5376>; +def SPV_BI_SMIDNV : I32EnumAttrCase<"SMIDNV", 5377>; + +def SPV_BuiltInAttr : + I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [ + SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance, + SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId, + SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter, + SPV_BI_TessLevelInner, SPV_BI_TessCoord, SPV_BI_PatchVertices, + SPV_BI_FragCoord, SPV_BI_PointCoord, SPV_BI_FrontFacing, SPV_BI_SampleId, + SPV_BI_SamplePosition, SPV_BI_SampleMask, SPV_BI_FragDepth, + SPV_BI_HelperInvocation, SPV_BI_NumWorkgroups, SPV_BI_WorkgroupSize, + SPV_BI_WorkgroupId, SPV_BI_LocalInvocationId, SPV_BI_GlobalInvocationId, + SPV_BI_LocalInvocationIndex, SPV_BI_WorkDim, SPV_BI_GlobalSize, + SPV_BI_EnqueuedWorkgroupSize, SPV_BI_GlobalOffset, SPV_BI_GlobalLinearId, + SPV_BI_SubgroupSize, SPV_BI_SubgroupMaxSize, SPV_BI_NumSubgroups, + SPV_BI_NumEnqueuedSubgroups, SPV_BI_SubgroupId, + SPV_BI_SubgroupLocalInvocationId, SPV_BI_VertexIndex, SPV_BI_InstanceIndex, + SPV_BI_SubgroupEqMask, SPV_BI_SubgroupGeMask, SPV_BI_SubgroupGtMask, + SPV_BI_SubgroupLeMask, SPV_BI_SubgroupLtMask, SPV_BI_BaseVertex, + SPV_BI_BaseInstance, SPV_BI_DrawIndex, SPV_BI_DeviceIndex, SPV_BI_ViewIndex, + SPV_BI_BaryCoordNoPerspAMD, SPV_BI_BaryCoordNoPerspCentroidAMD, + 
SPV_BI_BaryCoordNoPerspSampleAMD, SPV_BI_BaryCoordSmoothAMD, + SPV_BI_BaryCoordSmoothCentroidAMD, SPV_BI_BaryCoordSmoothSampleAMD, + SPV_BI_BaryCoordPullModelAMD, SPV_BI_FragStencilRefEXT, SPV_BI_ViewportMaskNV, + SPV_BI_SecondaryPositionNV, SPV_BI_SecondaryViewportMaskNV, + SPV_BI_PositionPerViewNV, SPV_BI_ViewportMaskPerViewNV, SPV_BI_FullyCoveredEXT, + SPV_BI_TaskCountNV, SPV_BI_PrimitiveCountNV, SPV_BI_PrimitiveIndicesNV, + SPV_BI_ClipDistancePerViewNV, SPV_BI_CullDistancePerViewNV, + SPV_BI_LayerPerViewNV, SPV_BI_MeshViewCountNV, SPV_BI_MeshViewIndicesNV, + SPV_BI_BaryCoordNV, SPV_BI_BaryCoordNoPerspNV, SPV_BI_FragSizeEXT, + SPV_BI_FragInvocationCountEXT, SPV_BI_LaunchIdNV, SPV_BI_LaunchSizeNV, + SPV_BI_WorldRayOriginNV, SPV_BI_WorldRayDirectionNV, SPV_BI_ObjectRayOriginNV, + SPV_BI_ObjectRayDirectionNV, SPV_BI_RayTminNV, SPV_BI_RayTmaxNV, + SPV_BI_InstanceCustomIndexNV, SPV_BI_ObjectToWorldNV, SPV_BI_WorldToObjectNV, + SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV, + SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>; +def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>; +def SPV_D_Block : I32EnumAttrCase<"Block", 2>; +def SPV_D_BufferBlock : I32EnumAttrCase<"BufferBlock", 3>; +def SPV_D_RowMajor : I32EnumAttrCase<"RowMajor", 4>; +def SPV_D_ColMajor : I32EnumAttrCase<"ColMajor", 5>; +def SPV_D_ArrayStride : I32EnumAttrCase<"ArrayStride", 6>; +def SPV_D_MatrixStride : I32EnumAttrCase<"MatrixStride", 7>; +def SPV_D_GLSLShared : I32EnumAttrCase<"GLSLShared", 8>; +def SPV_D_GLSLPacked : I32EnumAttrCase<"GLSLPacked", 9>; +def SPV_D_CPacked : I32EnumAttrCase<"CPacked", 10>; +def SPV_D_BuiltIn : I32EnumAttrCase<"BuiltIn", 11>; +def SPV_D_NoPerspective : I32EnumAttrCase<"NoPerspective", 13>; +def SPV_D_Flat : I32EnumAttrCase<"Flat", 14>; +def SPV_D_Patch : I32EnumAttrCase<"Patch", 15>; +def SPV_D_Centroid : 
I32EnumAttrCase<"Centroid", 16>; +def SPV_D_Sample : I32EnumAttrCase<"Sample", 17>; +def SPV_D_Invariant : I32EnumAttrCase<"Invariant", 18>; +def SPV_D_Restrict : I32EnumAttrCase<"Restrict", 19>; +def SPV_D_Aliased : I32EnumAttrCase<"Aliased", 20>; +def SPV_D_Volatile : I32EnumAttrCase<"Volatile", 21>; +def SPV_D_Constant : I32EnumAttrCase<"Constant", 22>; +def SPV_D_Coherent : I32EnumAttrCase<"Coherent", 23>; +def SPV_D_NonWritable : I32EnumAttrCase<"NonWritable", 24>; +def SPV_D_NonReadable : I32EnumAttrCase<"NonReadable", 25>; +def SPV_D_Uniform : I32EnumAttrCase<"Uniform", 26>; +def SPV_D_UniformId : I32EnumAttrCase<"UniformId", 27>; +def SPV_D_SaturatedConversion : I32EnumAttrCase<"SaturatedConversion", 28>; +def SPV_D_Stream : I32EnumAttrCase<"Stream", 29>; +def SPV_D_Location : I32EnumAttrCase<"Location", 30>; +def SPV_D_Component : I32EnumAttrCase<"Component", 31>; +def SPV_D_Index : I32EnumAttrCase<"Index", 32>; +def SPV_D_Binding : I32EnumAttrCase<"Binding", 33>; +def SPV_D_DescriptorSet : I32EnumAttrCase<"DescriptorSet", 34>; +def SPV_D_Offset : I32EnumAttrCase<"Offset", 35>; +def SPV_D_XfbBuffer : I32EnumAttrCase<"XfbBuffer", 36>; +def SPV_D_XfbStride : I32EnumAttrCase<"XfbStride", 37>; +def SPV_D_FuncParamAttr : I32EnumAttrCase<"FuncParamAttr", 38>; +def SPV_D_FPRoundingMode : I32EnumAttrCase<"FPRoundingMode", 39>; +def SPV_D_FPFastMathMode : I32EnumAttrCase<"FPFastMathMode", 40>; +def SPV_D_LinkageAttributes : I32EnumAttrCase<"LinkageAttributes", 41>; +def SPV_D_NoContraction : I32EnumAttrCase<"NoContraction", 42>; +def SPV_D_InputAttachmentIndex : I32EnumAttrCase<"InputAttachmentIndex", 43>; +def SPV_D_Alignment : I32EnumAttrCase<"Alignment", 44>; +def SPV_D_MaxByteOffset : I32EnumAttrCase<"MaxByteOffset", 45>; +def SPV_D_AlignmentId : I32EnumAttrCase<"AlignmentId", 46>; +def SPV_D_MaxByteOffsetId : I32EnumAttrCase<"MaxByteOffsetId", 47>; +def SPV_D_NoSignedWrap : I32EnumAttrCase<"NoSignedWrap", 4469>; +def SPV_D_NoUnsignedWrap : 
I32EnumAttrCase<"NoUnsignedWrap", 4470>; +def SPV_D_ExplicitInterpAMD : I32EnumAttrCase<"ExplicitInterpAMD", 4999>; +def SPV_D_OverrideCoverageNV : I32EnumAttrCase<"OverrideCoverageNV", 5248>; +def SPV_D_PassthroughNV : I32EnumAttrCase<"PassthroughNV", 5250>; +def SPV_D_ViewportRelativeNV : I32EnumAttrCase<"ViewportRelativeNV", 5252>; +def SPV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewportRelativeNV", 5256>; +def SPV_D_PerPrimitiveNV : I32EnumAttrCase<"PerPrimitiveNV", 5271>; +def SPV_D_PerViewNV : I32EnumAttrCase<"PerViewNV", 5272>; +def SPV_D_PerTaskNV : I32EnumAttrCase<"PerTaskNV", 5273>; +def SPV_D_PerVertexNV : I32EnumAttrCase<"PerVertexNV", 5285>; +def SPV_D_NonUniform : I32EnumAttrCase<"NonUniform", 5300>; +def SPV_D_RestrictPointer : I32EnumAttrCase<"RestrictPointer", 5355>; +def SPV_D_AliasedPointer : I32EnumAttrCase<"AliasedPointer", 5356>; +def SPV_D_CounterBuffer : I32EnumAttrCase<"CounterBuffer", 5634>; +def SPV_D_UserSemantic : I32EnumAttrCase<"UserSemantic", 5635>; +def SPV_D_UserTypeGOOGLE : I32EnumAttrCase<"UserTypeGOOGLE", 5636>; + +def SPV_DecorationAttr : + I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [ + SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock, + SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride, + SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn, + SPV_D_NoPerspective, SPV_D_Flat, SPV_D_Patch, SPV_D_Centroid, SPV_D_Sample, + SPV_D_Invariant, SPV_D_Restrict, SPV_D_Aliased, SPV_D_Volatile, SPV_D_Constant, + SPV_D_Coherent, SPV_D_NonWritable, SPV_D_NonReadable, SPV_D_Uniform, + SPV_D_UniformId, SPV_D_SaturatedConversion, SPV_D_Stream, SPV_D_Location, + SPV_D_Component, SPV_D_Index, SPV_D_Binding, SPV_D_DescriptorSet, SPV_D_Offset, + SPV_D_XfbBuffer, SPV_D_XfbStride, SPV_D_FuncParamAttr, SPV_D_FPRoundingMode, + SPV_D_FPFastMathMode, SPV_D_LinkageAttributes, SPV_D_NoContraction, + SPV_D_InputAttachmentIndex, SPV_D_Alignment, SPV_D_MaxByteOffset, + 
SPV_D_AlignmentId, SPV_D_MaxByteOffsetId, SPV_D_NoSignedWrap, + SPV_D_NoUnsignedWrap, SPV_D_ExplicitInterpAMD, SPV_D_OverrideCoverageNV, + SPV_D_PassthroughNV, SPV_D_ViewportRelativeNV, + SPV_D_SecondaryViewportRelativeNV, SPV_D_PerPrimitiveNV, SPV_D_PerViewNV, + SPV_D_PerTaskNV, SPV_D_PerVertexNV, SPV_D_NonUniform, SPV_D_RestrictPointer, + SPV_D_AliasedPointer, SPV_D_CounterBuffer, SPV_D_UserSemantic, + SPV_D_UserTypeGOOGLE + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_D_1D : I32EnumAttrCase<"Dim1D", 0>; +def SPV_D_2D : I32EnumAttrCase<"Dim2D", 1>; +def SPV_D_3D : I32EnumAttrCase<"Dim3D", 2>; +def SPV_D_Cube : I32EnumAttrCase<"Cube", 3>; +def SPV_D_Rect : I32EnumAttrCase<"Rect", 4>; +def SPV_D_Buffer : I32EnumAttrCase<"Buffer", 5>; +def SPV_D_SubpassData : I32EnumAttrCase<"SubpassData", 6>; + +def SPV_DimAttr : + I32EnumAttr<"Dim", "valid SPIR-V Dim", [ + SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer, + SPV_D_SubpassData + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_EM_Invocations : I32EnumAttrCase<"Invocations", 0>; +def SPV_EM_SpacingEqual : I32EnumAttrCase<"SpacingEqual", 1>; +def SPV_EM_SpacingFractionalEven : I32EnumAttrCase<"SpacingFractionalEven", 2>; +def SPV_EM_SpacingFractionalOdd : I32EnumAttrCase<"SpacingFractionalOdd", 3>; +def SPV_EM_VertexOrderCw : I32EnumAttrCase<"VertexOrderCw", 4>; +def SPV_EM_VertexOrderCcw : I32EnumAttrCase<"VertexOrderCcw", 5>; +def SPV_EM_PixelCenterInteger : I32EnumAttrCase<"PixelCenterInteger", 6>; +def SPV_EM_OriginUpperLeft : I32EnumAttrCase<"OriginUpperLeft", 7>; +def SPV_EM_OriginLowerLeft : I32EnumAttrCase<"OriginLowerLeft", 8>; +def SPV_EM_EarlyFragmentTests : I32EnumAttrCase<"EarlyFragmentTests", 9>; +def SPV_EM_PointMode : I32EnumAttrCase<"PointMode", 10>; +def SPV_EM_Xfb : I32EnumAttrCase<"Xfb", 11>; +def SPV_EM_DepthReplacing : I32EnumAttrCase<"DepthReplacing", 12>; +def SPV_EM_DepthGreater : I32EnumAttrCase<"DepthGreater", 14>; +def SPV_EM_DepthLess : 
I32EnumAttrCase<"DepthLess", 15>; +def SPV_EM_DepthUnchanged : I32EnumAttrCase<"DepthUnchanged", 16>; +def SPV_EM_LocalSize : I32EnumAttrCase<"LocalSize", 17>; +def SPV_EM_LocalSizeHint : I32EnumAttrCase<"LocalSizeHint", 18>; +def SPV_EM_InputPoints : I32EnumAttrCase<"InputPoints", 19>; +def SPV_EM_InputLines : I32EnumAttrCase<"InputLines", 20>; +def SPV_EM_InputLinesAdjacency : I32EnumAttrCase<"InputLinesAdjacency", 21>; +def SPV_EM_Triangles : I32EnumAttrCase<"Triangles", 22>; +def SPV_EM_InputTrianglesAdjacency : I32EnumAttrCase<"InputTrianglesAdjacency", 23>; +def SPV_EM_Quads : I32EnumAttrCase<"Quads", 24>; +def SPV_EM_Isolines : I32EnumAttrCase<"Isolines", 25>; +def SPV_EM_OutputVertices : I32EnumAttrCase<"OutputVertices", 26>; +def SPV_EM_OutputPoints : I32EnumAttrCase<"OutputPoints", 27>; +def SPV_EM_OutputLineStrip : I32EnumAttrCase<"OutputLineStrip", 28>; +def SPV_EM_OutputTriangleStrip : I32EnumAttrCase<"OutputTriangleStrip", 29>; +def SPV_EM_VecTypeHint : I32EnumAttrCase<"VecTypeHint", 30>; +def SPV_EM_ContractionOff : I32EnumAttrCase<"ContractionOff", 31>; +def SPV_EM_Initializer : I32EnumAttrCase<"Initializer", 33>; +def SPV_EM_Finalizer : I32EnumAttrCase<"Finalizer", 34>; +def SPV_EM_SubgroupSize : I32EnumAttrCase<"SubgroupSize", 35>; +def SPV_EM_SubgroupsPerWorkgroup : I32EnumAttrCase<"SubgroupsPerWorkgroup", 36>; +def SPV_EM_SubgroupsPerWorkgroupId : I32EnumAttrCase<"SubgroupsPerWorkgroupId", 37>; +def SPV_EM_LocalSizeId : I32EnumAttrCase<"LocalSizeId", 38>; +def SPV_EM_LocalSizeHintId : I32EnumAttrCase<"LocalSizeHintId", 39>; +def SPV_EM_PostDepthCoverage : I32EnumAttrCase<"PostDepthCoverage", 4446>; +def SPV_EM_DenormPreserve : I32EnumAttrCase<"DenormPreserve", 4459>; +def SPV_EM_DenormFlushToZero : I32EnumAttrCase<"DenormFlushToZero", 4460>; +def SPV_EM_SignedZeroInfNanPreserve : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4461>; +def SPV_EM_RoundingModeRTE : I32EnumAttrCase<"RoundingModeRTE", 4462>; +def SPV_EM_RoundingModeRTZ : 
I32EnumAttrCase<"RoundingModeRTZ", 4463>; +def SPV_EM_StencilRefReplacingEXT : I32EnumAttrCase<"StencilRefReplacingEXT", 5027>; +def SPV_EM_OutputLinesNV : I32EnumAttrCase<"OutputLinesNV", 5269>; +def SPV_EM_OutputPrimitivesNV : I32EnumAttrCase<"OutputPrimitivesNV", 5270>; +def SPV_EM_DerivativeGroupQuadsNV : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289>; +def SPV_EM_DerivativeGroupLinearNV : I32EnumAttrCase<"DerivativeGroupLinearNV", 5290>; +def SPV_EM_OutputTrianglesNV : I32EnumAttrCase<"OutputTrianglesNV", 5298>; +def SPV_EM_PixelInterlockOrderedEXT : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366>; +def SPV_EM_PixelInterlockUnorderedEXT : I32EnumAttrCase<"PixelInterlockUnorderedEXT", 5367>; +def SPV_EM_SampleInterlockOrderedEXT : I32EnumAttrCase<"SampleInterlockOrderedEXT", 5368>; +def SPV_EM_SampleInterlockUnorderedEXT : I32EnumAttrCase<"SampleInterlockUnorderedEXT", 5369>; +def SPV_EM_ShadingRateInterlockOrderedEXT : I32EnumAttrCase<"ShadingRateInterlockOrderedEXT", 5370>; +def SPV_EM_ShadingRateInterlockUnorderedEXT : I32EnumAttrCase<"ShadingRateInterlockUnorderedEXT", 5371>; + +def SPV_ExecutionModeAttr : + I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", [ + SPV_EM_Invocations, SPV_EM_SpacingEqual, SPV_EM_SpacingFractionalEven, + SPV_EM_SpacingFractionalOdd, SPV_EM_VertexOrderCw, SPV_EM_VertexOrderCcw, + SPV_EM_PixelCenterInteger, SPV_EM_OriginUpperLeft, SPV_EM_OriginLowerLeft, + SPV_EM_EarlyFragmentTests, SPV_EM_PointMode, SPV_EM_Xfb, SPV_EM_DepthReplacing, + SPV_EM_DepthGreater, SPV_EM_DepthLess, SPV_EM_DepthUnchanged, SPV_EM_LocalSize, + SPV_EM_LocalSizeHint, SPV_EM_InputPoints, SPV_EM_InputLines, + SPV_EM_InputLinesAdjacency, SPV_EM_Triangles, SPV_EM_InputTrianglesAdjacency, + SPV_EM_Quads, SPV_EM_Isolines, SPV_EM_OutputVertices, SPV_EM_OutputPoints, + SPV_EM_OutputLineStrip, SPV_EM_OutputTriangleStrip, SPV_EM_VecTypeHint, + SPV_EM_ContractionOff, SPV_EM_Initializer, SPV_EM_Finalizer, + SPV_EM_SubgroupSize, SPV_EM_SubgroupsPerWorkgroup, + 
SPV_EM_SubgroupsPerWorkgroupId, SPV_EM_LocalSizeId, SPV_EM_LocalSizeHintId, + SPV_EM_PostDepthCoverage, SPV_EM_DenormPreserve, SPV_EM_DenormFlushToZero, + SPV_EM_SignedZeroInfNanPreserve, SPV_EM_RoundingModeRTE, + SPV_EM_RoundingModeRTZ, SPV_EM_StencilRefReplacingEXT, SPV_EM_OutputLinesNV, + SPV_EM_OutputPrimitivesNV, SPV_EM_DerivativeGroupQuadsNV, + SPV_EM_DerivativeGroupLinearNV, SPV_EM_OutputTrianglesNV, + SPV_EM_PixelInterlockOrderedEXT, SPV_EM_PixelInterlockUnorderedEXT, + SPV_EM_SampleInterlockOrderedEXT, SPV_EM_SampleInterlockUnorderedEXT, + SPV_EM_ShadingRateInterlockOrderedEXT, SPV_EM_ShadingRateInterlockUnorderedEXT + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_EM_Vertex : I32EnumAttrCase<"Vertex", 0>; +def SPV_EM_TessellationControl : I32EnumAttrCase<"TessellationControl", 1>; +def SPV_EM_TessellationEvaluation : I32EnumAttrCase<"TessellationEvaluation", 2>; +def SPV_EM_Geometry : I32EnumAttrCase<"Geometry", 3>; +def SPV_EM_Fragment : I32EnumAttrCase<"Fragment", 4>; +def SPV_EM_GLCompute : I32EnumAttrCase<"GLCompute", 5>; +def SPV_EM_Kernel : I32EnumAttrCase<"Kernel", 6>; +def SPV_EM_TaskNV : I32EnumAttrCase<"TaskNV", 5267>; +def SPV_EM_MeshNV : I32EnumAttrCase<"MeshNV", 5268>; +def SPV_EM_RayGenerationNV : I32EnumAttrCase<"RayGenerationNV", 5313>; +def SPV_EM_IntersectionNV : I32EnumAttrCase<"IntersectionNV", 5314>; +def SPV_EM_AnyHitNV : I32EnumAttrCase<"AnyHitNV", 5315>; +def SPV_EM_ClosestHitNV : I32EnumAttrCase<"ClosestHitNV", 5316>; +def SPV_EM_MissNV : I32EnumAttrCase<"MissNV", 5317>; +def SPV_EM_CallableNV : I32EnumAttrCase<"CallableNV", 5318>; + +def SPV_ExecutionModelAttr : + I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", [ + SPV_EM_Vertex, SPV_EM_TessellationControl, SPV_EM_TessellationEvaluation, + SPV_EM_Geometry, SPV_EM_Fragment, SPV_EM_GLCompute, SPV_EM_Kernel, + SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationNV, SPV_EM_IntersectionNV, + SPV_EM_AnyHitNV, SPV_EM_ClosestHitNV, SPV_EM_MissNV, SPV_EM_CallableNV + 
]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_FC_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_FC_Inline : BitEnumAttrCase<"Inline", 0x0001>; +def SPV_FC_DontInline : BitEnumAttrCase<"DontInline", 0x0002>; +def SPV_FC_Pure : BitEnumAttrCase<"Pure", 0x0004>; +def SPV_FC_Const : BitEnumAttrCase<"Const", 0x0008>; + +def SPV_FunctionControlAttr : + BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ + SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_IF_Unknown : I32EnumAttrCase<"Unknown", 0>; +def SPV_IF_Rgba32f : I32EnumAttrCase<"Rgba32f", 1>; +def SPV_IF_Rgba16f : I32EnumAttrCase<"Rgba16f", 2>; +def SPV_IF_R32f : I32EnumAttrCase<"R32f", 3>; +def SPV_IF_Rgba8 : I32EnumAttrCase<"Rgba8", 4>; +def SPV_IF_Rgba8Snorm : I32EnumAttrCase<"Rgba8Snorm", 5>; +def SPV_IF_Rg32f : I32EnumAttrCase<"Rg32f", 6>; +def SPV_IF_Rg16f : I32EnumAttrCase<"Rg16f", 7>; +def SPV_IF_R11fG11fB10f : I32EnumAttrCase<"R11fG11fB10f", 8>; +def SPV_IF_R16f : I32EnumAttrCase<"R16f", 9>; +def SPV_IF_Rgba16 : I32EnumAttrCase<"Rgba16", 10>; +def SPV_IF_Rgb10A2 : I32EnumAttrCase<"Rgb10A2", 11>; +def SPV_IF_Rg16 : I32EnumAttrCase<"Rg16", 12>; +def SPV_IF_Rg8 : I32EnumAttrCase<"Rg8", 13>; +def SPV_IF_R16 : I32EnumAttrCase<"R16", 14>; +def SPV_IF_R8 : I32EnumAttrCase<"R8", 15>; +def SPV_IF_Rgba16Snorm : I32EnumAttrCase<"Rgba16Snorm", 16>; +def SPV_IF_Rg16Snorm : I32EnumAttrCase<"Rg16Snorm", 17>; +def SPV_IF_Rg8Snorm : I32EnumAttrCase<"Rg8Snorm", 18>; +def SPV_IF_R16Snorm : I32EnumAttrCase<"R16Snorm", 19>; +def SPV_IF_R8Snorm : I32EnumAttrCase<"R8Snorm", 20>; +def SPV_IF_Rgba32i : I32EnumAttrCase<"Rgba32i", 21>; +def SPV_IF_Rgba16i : I32EnumAttrCase<"Rgba16i", 22>; +def SPV_IF_Rgba8i : I32EnumAttrCase<"Rgba8i", 23>; +def SPV_IF_R32i : I32EnumAttrCase<"R32i", 24>; +def SPV_IF_Rg32i : I32EnumAttrCase<"Rg32i", 25>; +def SPV_IF_Rg16i : I32EnumAttrCase<"Rg16i", 26>; +def SPV_IF_Rg8i : I32EnumAttrCase<"Rg8i", 
27>; +def SPV_IF_R16i : I32EnumAttrCase<"R16i", 28>; +def SPV_IF_R8i : I32EnumAttrCase<"R8i", 29>; +def SPV_IF_Rgba32ui : I32EnumAttrCase<"Rgba32ui", 30>; +def SPV_IF_Rgba16ui : I32EnumAttrCase<"Rgba16ui", 31>; +def SPV_IF_Rgba8ui : I32EnumAttrCase<"Rgba8ui", 32>; +def SPV_IF_R32ui : I32EnumAttrCase<"R32ui", 33>; +def SPV_IF_Rgb10a2ui : I32EnumAttrCase<"Rgb10a2ui", 34>; +def SPV_IF_Rg32ui : I32EnumAttrCase<"Rg32ui", 35>; +def SPV_IF_Rg16ui : I32EnumAttrCase<"Rg16ui", 36>; +def SPV_IF_Rg8ui : I32EnumAttrCase<"Rg8ui", 37>; +def SPV_IF_R16ui : I32EnumAttrCase<"R16ui", 38>; +def SPV_IF_R8ui : I32EnumAttrCase<"R8ui", 39>; + +def SPV_ImageFormatAttr : + I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [ + SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8, + SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f, + SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8, + SPV_IF_R16, SPV_IF_R8, SPV_IF_Rgba16Snorm, SPV_IF_Rg16Snorm, SPV_IF_Rg8Snorm, + SPV_IF_R16Snorm, SPV_IF_R8Snorm, SPV_IF_Rgba32i, SPV_IF_Rgba16i, SPV_IF_Rgba8i, + SPV_IF_R32i, SPV_IF_Rg32i, SPV_IF_Rg16i, SPV_IF_Rg8i, SPV_IF_R16i, SPV_IF_R8i, + SPV_IF_Rgba32ui, SPV_IF_Rgba16ui, SPV_IF_Rgba8ui, SPV_IF_R32ui, + SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui, + SPV_IF_R8ui + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_LT_Export : I32EnumAttrCase<"Export", 0>; +def SPV_LT_Import : I32EnumAttrCase<"Import", 1>; + +def SPV_LinkageTypeAttr : + I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [ + SPV_LT_Export, SPV_LT_Import + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_LC_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_LC_Unroll : BitEnumAttrCase<"Unroll", 0x0001>; +def SPV_LC_DontUnroll : BitEnumAttrCase<"DontUnroll", 0x0002>; +def SPV_LC_DependencyInfinite : BitEnumAttrCase<"DependencyInfinite", 0x0004>; +def SPV_LC_DependencyLength : BitEnumAttrCase<"DependencyLength", 0x0008>; +def 
SPV_LC_MinIterations : BitEnumAttrCase<"MinIterations", 0x0010>; +def SPV_LC_MaxIterations : BitEnumAttrCase<"MaxIterations", 0x0020>; +def SPV_LC_IterationMultiple : BitEnumAttrCase<"IterationMultiple", 0x0040>; +def SPV_LC_PeelCount : BitEnumAttrCase<"PeelCount", 0x0080>; +def SPV_LC_PartialCount : BitEnumAttrCase<"PartialCount", 0x0100>; + +def SPV_LoopControlAttr : + BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [ + SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite, + SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations, + SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_MA_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_MA_Volatile : BitEnumAttrCase<"Volatile", 0x0001>; +def SPV_MA_Aligned : BitEnumAttrCase<"Aligned", 0x0002>; +def SPV_MA_Nontemporal : BitEnumAttrCase<"Nontemporal", 0x0004>; +def SPV_MA_MakePointerAvailable : BitEnumAttrCase<"MakePointerAvailable", 0x0008>; +def SPV_MA_MakePointerVisible : BitEnumAttrCase<"MakePointerVisible", 0x0010>; +def SPV_MA_NonPrivatePointer : BitEnumAttrCase<"NonPrivatePointer", 0x0020>; + +def SPV_MemoryAccessAttr : + BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [ + SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal, + SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible, + SPV_MA_NonPrivatePointer + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_MM_Simple : I32EnumAttrCase<"Simple", 0>; +def SPV_MM_GLSL450 : I32EnumAttrCase<"GLSL450", 1>; +def SPV_MM_OpenCL : I32EnumAttrCase<"OpenCL", 2>; +def SPV_MM_Vulkan : I32EnumAttrCase<"Vulkan", 3>; + +def SPV_MemoryModelAttr : + I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [ + SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_MS_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_MS_Acquire : BitEnumAttrCase<"Acquire", 0x0002>; +def 
SPV_MS_Release : BitEnumAttrCase<"Release", 0x0004>; +def SPV_MS_AcquireRelease : BitEnumAttrCase<"AcquireRelease", 0x0008>; +def SPV_MS_SequentiallyConsistent : BitEnumAttrCase<"SequentiallyConsistent", 0x0010>; +def SPV_MS_UniformMemory : BitEnumAttrCase<"UniformMemory", 0x0040>; +def SPV_MS_SubgroupMemory : BitEnumAttrCase<"SubgroupMemory", 0x0080>; +def SPV_MS_WorkgroupMemory : BitEnumAttrCase<"WorkgroupMemory", 0x0100>; +def SPV_MS_CrossWorkgroupMemory : BitEnumAttrCase<"CrossWorkgroupMemory", 0x0200>; +def SPV_MS_AtomicCounterMemory : BitEnumAttrCase<"AtomicCounterMemory", 0x0400>; +def SPV_MS_ImageMemory : BitEnumAttrCase<"ImageMemory", 0x0800>; +def SPV_MS_OutputMemory : BitEnumAttrCase<"OutputMemory", 0x1000>; +def SPV_MS_MakeAvailable : BitEnumAttrCase<"MakeAvailable", 0x2000>; +def SPV_MS_MakeVisible : BitEnumAttrCase<"MakeVisible", 0x4000>; +def SPV_MS_Volatile : BitEnumAttrCase<"Volatile", 0x8000>; + +def SPV_MemorySemanticsAttr : + BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", [ + SPV_MS_None, SPV_MS_Acquire, SPV_MS_Release, SPV_MS_AcquireRelease, + SPV_MS_SequentiallyConsistent, SPV_MS_UniformMemory, SPV_MS_SubgroupMemory, + SPV_MS_WorkgroupMemory, SPV_MS_CrossWorkgroupMemory, + SPV_MS_AtomicCounterMemory, SPV_MS_ImageMemory, SPV_MS_OutputMemory, + SPV_MS_MakeAvailable, SPV_MS_MakeVisible, SPV_MS_Volatile + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_S_CrossDevice : I32EnumAttrCase<"CrossDevice", 0>; +def SPV_S_Device : I32EnumAttrCase<"Device", 1>; +def SPV_S_Workgroup : I32EnumAttrCase<"Workgroup", 2>; +def SPV_S_Subgroup : I32EnumAttrCase<"Subgroup", 3>; +def SPV_S_Invocation : I32EnumAttrCase<"Invocation", 4>; +def SPV_S_QueueFamily : I32EnumAttrCase<"QueueFamily", 5>; + +def SPV_ScopeAttr : + I32EnumAttr<"Scope", "valid SPIR-V Scope", [ + SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup, + SPV_S_Invocation, SPV_S_QueueFamily + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_SC_None : 
BitEnumAttrCase<"None", 0x0000>; +def SPV_SC_Flatten : BitEnumAttrCase<"Flatten", 0x0001>; +def SPV_SC_DontFlatten : BitEnumAttrCase<"DontFlatten", 0x0002>; + +def SPV_SelectionControlAttr : + BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [ + SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten + ]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_SC_UniformConstant : I32EnumAttrCase<"UniformConstant", 0>; +def SPV_SC_Input : I32EnumAttrCase<"Input", 1>; +def SPV_SC_Uniform : I32EnumAttrCase<"Uniform", 2>; +def SPV_SC_Output : I32EnumAttrCase<"Output", 3>; +def SPV_SC_Workgroup : I32EnumAttrCase<"Workgroup", 4>; +def SPV_SC_CrossWorkgroup : I32EnumAttrCase<"CrossWorkgroup", 5>; +def SPV_SC_Private : I32EnumAttrCase<"Private", 6>; +def SPV_SC_Function : I32EnumAttrCase<"Function", 7>; +def SPV_SC_Generic : I32EnumAttrCase<"Generic", 8>; +def SPV_SC_PushConstant : I32EnumAttrCase<"PushConstant", 9>; +def SPV_SC_AtomicCounter : I32EnumAttrCase<"AtomicCounter", 10>; +def SPV_SC_Image : I32EnumAttrCase<"Image", 11>; +def SPV_SC_StorageBuffer : I32EnumAttrCase<"StorageBuffer", 12>; +def SPV_SC_CallableDataNV : I32EnumAttrCase<"CallableDataNV", 5328>; +def SPV_SC_IncomingCallableDataNV : I32EnumAttrCase<"IncomingCallableDataNV", 5329>; +def SPV_SC_RayPayloadNV : I32EnumAttrCase<"RayPayloadNV", 5338>; +def SPV_SC_HitAttributeNV : I32EnumAttrCase<"HitAttributeNV", 5339>; +def SPV_SC_IncomingRayPayloadNV : I32EnumAttrCase<"IncomingRayPayloadNV", 5342>; +def SPV_SC_ShaderRecordBufferNV : I32EnumAttrCase<"ShaderRecordBufferNV", 5343>; +def SPV_SC_PhysicalStorageBuffer : I32EnumAttrCase<"PhysicalStorageBuffer", 5349>; + +def SPV_StorageClassAttr : + I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [ + SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output, + SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function, + SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image, + SPV_SC_StorageBuffer, 
SPV_SC_CallableDataNV, SPV_SC_IncomingCallableDataNV, + SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV, + SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBuffer + ]> { + let cppNamespace = "::mlir::spirv"; +} + +// End enum section. Generated from SPIR-V spec; DO NOT MODIFY! + +// Enums added manually that are not part of SPIR-V spec + +def SPV_IDI_NoDepth : I32EnumAttrCase<"NoDepth", 0>; +def SPV_IDI_IsDepth : I32EnumAttrCase<"IsDepth", 1>; +def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>; + +def SPV_DepthAttr : + I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification", + [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>; +def SPV_IAI_Arrayed : I32EnumAttrCase<"Arrayed", 1>; + +def SPV_ArrayedAttr : + I32EnumAttr<"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification", + [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>; +def SPV_ISI_MultiSampled : I32EnumAttrCase<"MultiSampled", 1>; + +def SPV_SamplingAttr: + I32EnumAttr<"ImageSamplingInfo", "valid SPIR-V Image Sampling specification", + [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]> { + let cppNamespace = "::mlir::spirv"; +} + +def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>; +def SPV_ISUI_NeedSampler : I32EnumAttrCase<"NeedSampler", 1>; +def SPV_ISUI_NoSampler : I32EnumAttrCase<"NoSampler", 2>; + +def SPV_SamplerUseAttr: + I32EnumAttr<"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification", + [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]> { + let cppNamespace = "::mlir::spirv"; +} + +//===----------------------------------------------------------------------===// +// SPIR-V type definitions +//===----------------------------------------------------------------------===// + +def SPV_IsPtrType : 
CPred<"$_self.isa<::mlir::spirv::PointerType>()">; +def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; +def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; +def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; + +// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types +// for the definition of the following types and type categories. + +def SPV_Void : TypeAlias; +def SPV_Bool : IntOfWidths<[1]>; +def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>; +def SPV_Float : FloatOfWidths<[16, 32, 64]>; +def SPV_Float16or32 : FloatOfWidths<[16, 32]>; +def SPV_Vector : VectorOfLengthAndType<[2, 3, 4], + [SPV_Bool, SPV_Integer, SPV_Float]>; +// Component type check is done in the type parser for the following SPIR-V +// dialect-specific types so we use "Any" here. +def SPV_AnyPtr : Type; +def SPV_AnyArray : Type; +def SPV_AnyRTArray : Type; +def SPV_AnyStruct : Type; + +def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; +def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; +def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>; +def SPV_Composite : + AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; +def SPV_Type : AnyTypeOf<[ + SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, + SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct + ]>; + +class SPV_ScalarOrVectorOf : + AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>; + +def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; +def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; + +class SPV_Vec4 : VectorOfLengthAndType<[4], [type]>; +def SPV_IntVec4 : SPV_Vec4; +def SPV_I32Vec4 : SPV_Vec4; + +// TODO(antiagainst): Use a more appropriate way to model optional operands +class SPV_Optional : Variadic; + +// TODO(ravishankarm): From 1.4, this should also include Composite type. 
+def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; + +//===----------------------------------------------------------------------===// +// SPIR-V OpTrait definitions +//===----------------------------------------------------------------------===// + +// Check that an op can only be used within the scope of a FuncOp. +def InFunctionScope : PredOpTrait< + "op must appear in a 'func' block", + CPred<"($_op.getParentOfType())">>; + +// Check that an op can only be used within the scope of a SPIR-V ModuleOp. +def InModuleScope : PredOpTrait< + "op must appear in a 'spv.module' block", + CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>; + +//===----------------------------------------------------------------------===// +// SPIR-V opcode specification +//===----------------------------------------------------------------------===// + +class SPV_OpCode { + // Name used as reference to retrieve the opcode + string opname = name; + + // Opcode associated with the name + int opcode = val; +} + +// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY! 
+ +def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>; +def SPV_OC_OpUndef : I32EnumAttrCase<"OpUndef", 1>; +def SPV_OC_OpSourceContinued : I32EnumAttrCase<"OpSourceContinued", 2>; +def SPV_OC_OpSource : I32EnumAttrCase<"OpSource", 3>; +def SPV_OC_OpSourceExtension : I32EnumAttrCase<"OpSourceExtension", 4>; +def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>; +def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>; +def SPV_OC_OpString : I32EnumAttrCase<"OpString", 7>; +def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>; +def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>; +def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>; +def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>; +def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>; +def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>; +def SPV_OC_OpCapability : I32EnumAttrCase<"OpCapability", 17>; +def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>; +def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>; +def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>; +def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>; +def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>; +def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>; +def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>; +def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; +def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; +def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; +def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>; +def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>; +def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>; +def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>; +def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>; +def SPV_OC_OpSpecConstantTrue : I32EnumAttrCase<"OpSpecConstantTrue", 48>; +def 
SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>; +def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>; +def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>; +def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>; +def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>; +def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>; +def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>; +def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>; +def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>; +def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; +def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; +def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; +def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; +def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>; +def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; +def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>; +def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; +def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>; +def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>; +def SPV_OC_OpConvertUToF : I32EnumAttrCase<"OpConvertUToF", 112>; +def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>; +def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; +def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; +def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; +def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; +def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; +def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; +def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>; +def SPV_OC_OpFSub : I32EnumAttrCase<"OpFSub", 131>; +def SPV_OC_OpIMul : I32EnumAttrCase<"OpIMul", 132>; +def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>; +def SPV_OC_OpUDiv 
: I32EnumAttrCase<"OpUDiv", 134>; +def SPV_OC_OpSDiv : I32EnumAttrCase<"OpSDiv", 135>; +def SPV_OC_OpFDiv : I32EnumAttrCase<"OpFDiv", 136>; +def SPV_OC_OpUMod : I32EnumAttrCase<"OpUMod", 137>; +def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>; +def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>; +def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; +def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; +def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; +def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; +def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; +def SPV_OC_OpLogicalAnd : I32EnumAttrCase<"OpLogicalAnd", 167>; +def SPV_OC_OpLogicalNot : I32EnumAttrCase<"OpLogicalNot", 168>; +def SPV_OC_OpSelect : I32EnumAttrCase<"OpSelect", 169>; +def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>; +def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>; +def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>; +def SPV_OC_OpSGreaterThan : I32EnumAttrCase<"OpSGreaterThan", 173>; +def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>; +def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>; +def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>; +def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>; +def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>; +def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>; +def SPV_OC_OpFOrdEqual : I32EnumAttrCase<"OpFOrdEqual", 180>; +def SPV_OC_OpFUnordEqual : I32EnumAttrCase<"OpFUnordEqual", 181>; +def SPV_OC_OpFOrdNotEqual : I32EnumAttrCase<"OpFOrdNotEqual", 182>; +def SPV_OC_OpFUnordNotEqual : I32EnumAttrCase<"OpFUnordNotEqual", 183>; +def SPV_OC_OpFOrdLessThan : I32EnumAttrCase<"OpFOrdLessThan", 184>; +def SPV_OC_OpFUnordLessThan : I32EnumAttrCase<"OpFUnordLessThan", 185>; +def SPV_OC_OpFOrdGreaterThan : I32EnumAttrCase<"OpFOrdGreaterThan", 186>; +def 
SPV_OC_OpFUnordGreaterThan : I32EnumAttrCase<"OpFUnordGreaterThan", 187>; +def SPV_OC_OpFOrdLessThanEqual : I32EnumAttrCase<"OpFOrdLessThanEqual", 188>; +def SPV_OC_OpFUnordLessThanEqual : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>; +def SPV_OC_OpFOrdGreaterThanEqual : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>; +def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>; +def SPV_OC_OpShiftRightLogical : I32EnumAttrCase<"OpShiftRightLogical", 194>; +def SPV_OC_OpShiftRightArithmetic : I32EnumAttrCase<"OpShiftRightArithmetic", 195>; +def SPV_OC_OpShiftLeftLogical : I32EnumAttrCase<"OpShiftLeftLogical", 196>; +def SPV_OC_OpBitwiseOr : I32EnumAttrCase<"OpBitwiseOr", 197>; +def SPV_OC_OpBitwiseXor : I32EnumAttrCase<"OpBitwiseXor", 198>; +def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>; +def SPV_OC_OpNot : I32EnumAttrCase<"OpNot", 200>; +def SPV_OC_OpBitFieldInsert : I32EnumAttrCase<"OpBitFieldInsert", 201>; +def SPV_OC_OpBitFieldSExtract : I32EnumAttrCase<"OpBitFieldSExtract", 202>; +def SPV_OC_OpBitFieldUExtract : I32EnumAttrCase<"OpBitFieldUExtract", 203>; +def SPV_OC_OpBitReverse : I32EnumAttrCase<"OpBitReverse", 204>; +def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>; +def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>; +def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>; +def SPV_OC_OpAtomicCompareExchangeWeak : I32EnumAttrCase<"OpAtomicCompareExchangeWeak", 231>; +def SPV_OC_OpAtomicIIncrement : I32EnumAttrCase<"OpAtomicIIncrement", 232>; +def SPV_OC_OpAtomicIDecrement : I32EnumAttrCase<"OpAtomicIDecrement", 233>; +def SPV_OC_OpAtomicIAdd : I32EnumAttrCase<"OpAtomicIAdd", 234>; +def SPV_OC_OpAtomicISub : I32EnumAttrCase<"OpAtomicISub", 235>; +def SPV_OC_OpAtomicSMin : I32EnumAttrCase<"OpAtomicSMin", 236>; +def SPV_OC_OpAtomicUMin : I32EnumAttrCase<"OpAtomicUMin", 237>; +def SPV_OC_OpAtomicSMax : I32EnumAttrCase<"OpAtomicSMax", 238>; +def SPV_OC_OpAtomicUMax : 
I32EnumAttrCase<"OpAtomicUMax", 239>; +def SPV_OC_OpAtomicAnd : I32EnumAttrCase<"OpAtomicAnd", 240>; +def SPV_OC_OpAtomicOr : I32EnumAttrCase<"OpAtomicOr", 241>; +def SPV_OC_OpAtomicXor : I32EnumAttrCase<"OpAtomicXor", 242>; +def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>; +def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>; +def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>; +def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; +def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>; +def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>; +def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; +def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; +def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>; +def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; +def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>; +def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; + +def SPV_OpcodeAttr : + I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ + SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource, + SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString, + SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, + SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, + SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, + SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, + SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, + SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, + SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, + SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, 
SPV_OC_OpLoad, + SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct, + SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU, + SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, + SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, + SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, + SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, + SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, + SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, + SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, + SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, + SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, + SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, + SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, + SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, + SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, + SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, + SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, + SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, + SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, + SPV_OC_OpBranch, 
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed, + SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR + ]> { + let cppNamespace = "::mlir::spirv"; +} + +// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! + +//===----------------------------------------------------------------------===// +// SPIR-V op definitions +//===----------------------------------------------------------------------===// + +// Base class for all SPIR-V ops. +class SPV_Op traits = []> : + Op { + + // For each SPIR-V op, the following static functions need to be defined + // in SPVOps.cpp: + // + // * static ParseResult parse(OpAsmParser &parser, + // OperationState &result) + // * static void print(OpAsmPrinter &p, op) + // * static LogicalResult verify( op) + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(*this, p); }]; + let verifier = [{ return ::verify(*this); }]; + + // Specifies whether this op has a direct corresponding SPIR-V binary + // instruction opcode. The (de)serializer use this field to determine whether + // to auto-generate an entry in the (de)serialization dispatch table for this + // op. + bit hasOpcode = 1; + + // Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1. + string spirvOpName = "Op" # mnemonic; + + // Controls whether to auto-generate this op's (de)serialization method. + // If set, it results in generation of the following methods: + // + // ```c++ + // template Serializer::processOp(OpTy op); + // template Deserializer::processOp(ArrayRef); + // ``` + // + // If this field is not set, then manual implementation of a specialization of + // these methods is required. + // + // Note: + // 1) If hasOpcode is set but autogenSerialization is not set, the + // (de)serializer dispatch method still calls the above method for + // (de)serializing this op. 
+ // 2) If hasOpcode is not set, but autogenSerialization is set, the + // above methods for (de)serialization are generated, but there is no + // entry added in the dispatch tables to invoke these methods. The + // dispatch needs to be handled manually. SPV_ExtInstOps are an + // example of this. + bit autogenSerialization = 1; +} + +class SPV_UnaryOp traits = []> : + SPV_Op { + let arguments = (ins + SPV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return ::parseUnaryOp(parser, result); }]; + let printer = [{ return ::printUnaryOp(getOperation(), p); }]; + // No additional verification needed in addition to the ODS-generated ones. + let verifier = [{ return success(); }]; +} + +class SPV_BinaryOp traits = []> : + SPV_Op { + let arguments = (ins + SPV_ScalarOrVectorOf:$operand1, + SPV_ScalarOrVectorOf:$operand2 + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; + // No additional verification needed in addition to the ODS-generated ones. + let verifier = [{ return success(); }]; +} + +class SPV_ExtInstOp traits = []> : + SPV_Op { + + // Extended instruction sets have no direct opcode (they share the + // same `OpExtInst` instruction). So the hasOpcode field is set to + // false. So no entry corresponding to these ops are added in the + // dispatch functions for (de)serialization. The methods for + // (de)serialization are still automatically generated (since + // autogenSerialization remains 1). A separate method is generated + // for dispatching extended instruction set ops. + let hasOpcode = 0; + + // Opcode within extended instruction set. + int extendedInstOpcode = opcode; + + // Name used to import the extended instruction set. 
+ string extendedInstSetName = setName; +} + +#endif // SPIRV_BASE diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..6a4264884238ad9fc2a16cdb0ff229a1d6f2b40e --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h @@ -0,0 +1,49 @@ +//===- SPIRVBinaryUtils.cpp - SPIR-V Binary Module Utils --------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares common utilities for SPIR-V binary module. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_ +#define MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_ + +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Support/LogicalResult.h" + +#include + +namespace mlir { +namespace spirv { + +/// SPIR-V binary header word count +constexpr unsigned kHeaderWordCount = 5; + +/// SPIR-V magic number +constexpr uint32_t kMagicNumber = 0x07230203; + +/// The serializer tool ID registered to the Khronos Group +constexpr uint32_t kGeneratorNumber = 22; + +/// Auto-generated getOpcode<*Op>() specializations +#define GET_SPIRV_SERIALIZATION_UTILS +#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" + +/// Appends a SPRI-V module header to `header` with the given `idBound`. +void appendModuleHeader(SmallVectorImpl &header, uint32_t idBound); + +/// Returns the word-count-prefixed opcode for an SPIR-V instruction. +uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode); + +/// Encodes an SPIR-V `literal` string into the given `binary` vector. 
+LogicalResult encodeStringLiteralInto(SmallVectorImpl &binary, + StringRef literal); +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td new file mode 100644 index 0000000000000000000000000000000000000000..360edeec52d6ef57ab15e038174c7707fde69add --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td @@ -0,0 +1,523 @@ +//===-- SPIRVBitOps.td - MLIR SPIR-V Bit Ops -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains bit ops for the SPIR-V dialect. It corresponds +// to "3.32.13. Bit Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_BIT_OPS +#define SPIRV_BIT_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +class SPV_BitBinaryOp traits = []> : + // All the operands type used in bit instructions are SPV_Integer. 
+ SPV_BinaryOp; + +class SPV_BitFieldExtractOp traits = []> : + SPV_Op { + let arguments = (ins + SPV_ScalarOrVectorOf:$base, + SPV_Integer:$offset, + SPV_Integer:$count + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return ::parseBitFieldExtractOp(parser, result); }]; + let printer = [{ ::printBitFieldExtractOp(this->getOperation(), p); }]; + let verifier = [{ return ::verifyBitFieldExtractOp(this->getOperation()); }]; +} + +class SPV_BitUnaryOp traits = []> : + SPV_UnaryOp; + +class SPV_ShiftOp traits = []> : + SPV_BinaryOp { + let parser = [{ return ::parseShiftOp(parser, result); }]; + let printer = [{ ::printShiftOp(this->getOperation(), p); }]; + let verifier = [{ return ::verifyShiftOp(this->getOperation()); }]; +} + +// ----- + +def SPV_BitCountOp : SPV_BitUnaryOp<"BitCount", []> { + let summary = "Count the number of set bits in an object."; + + let description = [{ + Results are computed per component. + + Result Type must be a scalar or vector of integer type. The components + must be wide enough to hold the unsigned Width of Base as an unsigned + value. That is, no sign bit is needed or counted when checking for a + wide enough result width. + + Base must be a scalar or vector of integer type. It must have the same + number of components as Result Type. + + The result is the unsigned value that is the number of bits in Base that + are 1. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitcount-op ::= ssa-id `=` `spv.BitCount` ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.BitCount %0: i32 + %3 = spv.BitCount %1: vector<4xi32> + ``` + }]; +} + +// ----- + +def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", [NoSideEffect]> { + let summary = [{ + Make a copy of an object, with a modified bit field that comes from + another object. 
+ }]; + + let description = [{ + Results are computed per component. + + Result Type must be a scalar or vector of integer type. + + The type of Base and Insert must be the same as Result Type. + + Any result bits numbered outside [Offset, Offset + Count - 1] + (inclusive) will come from the corresponding bits in Base. + + Any result bits numbered in [Offset, Offset + Count - 1] come, in + order, from the bits numbered [0, Count - 1] of Insert. + + Count must be an integer type scalar. Count is the number of bits taken + from Insert. It will be consumed as an unsigned value. Count can be 0, + in which case the result will be Base. + + Offset must be an integer type scalar. Offset is the lowest-order bit + of the bit field. It will be consumed as an unsigned value. + + The resulting value is undefined if Count or Offset or their sum is + greater than the number of bits in the result. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitfield-insert-op ::= ssa-id `=` `spv.BitFieldInsert` ssa-use `,` ssa-use + `,` ssa-use `,` ssa-use + `:` integer-scalar-vector-type + `,` integer-type `,` integer-type + ``` + + For example: + + ``` + %0 = spv.BitFieldInsert %base, %insert, %offset, %count : vector<3xi32>, i8, i8 + ``` + }]; + + let arguments = (ins + SPV_ScalarOrVectorOf:$base, + SPV_ScalarOrVectorOf:$insert, + SPV_Integer:$offset, + SPV_Integer:$count + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); +} + +// ----- + +def SPV_BitFieldSExtractOp : SPV_BitFieldExtractOp<"BitFieldSExtract", []> { + let summary = "Extract a bit field from an object, with sign extension."; + + let description = [{ + Results are computed per component. + + Result Type must be a scalar or vector of integer type. + + The type of Base must be the same as Result Type. 
+ + If Count is greater than 0: The bits of Base numbered in [Offset, Offset + + Count - 1] (inclusive) become the bits numbered [0, Count - 1] of the + result. The remaining bits of the result will all be the same as bit + Offset + Count - 1 of Base. + + Count must be an integer type scalar. Count is the number of bits + extracted from Base. It will be consumed as an unsigned value. Count can + be 0, in which case the result will be 0. + + Offset must be an integer type scalar. Offset is the lowest-order bit + of the bit field to extract from Base. It will be consumed as an + unsigned value. + + The resulting value is undefined if Count or Offset or their sum is + greater than the number of bits in the result. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitfield-extract-s-op ::= ssa-id `=` `spv.BitFieldSExtract` ssa-use + `,` ssa-use `,` ssa-use + `:` integer-scalar-vector-type + `,` integer-type `,` integer-type + ``` + + For example: + + ``` + %0 = spv.BitFieldSExtract %base, %offset, %count : vector<3xi32>, i8, i8 + ``` + }]; +} + +// ----- + +def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract", []> { + let summary = "Extract a bit field from an object, without sign extension."; + + let description = [{ + The semantics are the same as with OpBitFieldSExtract with the exception + that there is no sign extension. The remaining bits of the result will + all be 0. 
+ + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitfield-extract-u-op ::= ssa-id `=` `spv.BitFieldUExtract` ssa-use + `,` ssa-use `,` ssa-use + `:` integer-scalar-vector-type + `,` integer-type `,` integer-type + ``` + + For example: + + ``` + %0 = spv.BitFieldUExtract %base, %offset, %count : vector<3xi32>, i8, i8 + ``` + }]; +} + +// ----- + +def SPV_BitReverseOp : SPV_BitUnaryOp<"BitReverse", []> { + let summary = "Reverse the bits in an object."; + + let description = [{ + Results are computed per component. + + Result Type must be a scalar or vector of integer type. + + The type of Base must be the same as Result Type. + + The bit-number n of the result will be taken from bit-number Width - 1 - + n of Base, where Width is the OpTypeInt operand of the Result Type. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitreverse-op ::= ssa-id `=` `spv.BitReverse` ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.BitReverse %0 : i32 + %3 = spv.BitReverse %1 : vector<4xi32> + ``` + }]; +} + +// ----- + +def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> { + let summary = [{ + Result is 1 if both Operand 1 and Operand 2 are 1. Result is 0 if either + Operand 1 or Operand 2 are 0. + }]; + + let description = [{ + Results are computed per component, and within each component, per bit. + + Result Type must be a scalar or vector of integer type. The type of + Operand 1 and Operand 2 must be a scalar or vector of integer type. + They must have the same number of components as Result Type. They must + have the same component width as Result Type. 
+ + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitwise-and-op ::= ssa-id `=` `spv.BitwiseAnd` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.BitwiseAnd %0, %1 : i32 + %2 = spv.BitwiseAnd %0, %1 : vector<4xi32> + ``` + }]; +} + +// ----- + +def SPV_BitwiseOrOp : SPV_BitBinaryOp<"BitwiseOr", [Commutative]> { + let summary = [{ + Result is 1 if either Operand 1 or Operand 2 is 1. Result is 0 if both + Operand 1 and Operand 2 are 0. + }]; + + let description = [{ + Results are computed per component, and within each component, per bit. + + Result Type must be a scalar or vector of integer type. The type of + Operand 1 and Operand 2 must be a scalar or vector of integer type. + They must have the same number of components as Result Type. They must + have the same component width as Result Type. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitwise-or-op ::= ssa-id `=` `spv.BitwiseOr` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.BitwiseOr %0, %1 : i32 + %2 = spv.BitwiseOr %0, %1 : vector<4xi32> + ``` + }]; +} + +// ----- + +def SPV_BitwiseXorOp : SPV_BitBinaryOp<"BitwiseXor", [Commutative]> { + let summary = [{ + Result is 1 if exactly one of Operand 1 or Operand 2 is 1. Result is 0 + if Operand 1 and Operand 2 have the same value. + }]; + + let description = [{ + Results are computed per component, and within each component, per bit. + + Result Type must be a scalar or vector of integer type. The type of + Operand 1 and Operand 2 must be a scalar or vector of integer type. + They must have the same number of components as Result Type. They must + have the same component width as Result Type. 
+ + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + bitwise-xor-op ::= ssa-id `=` `spv.BitwiseXor` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.BitwiseXor %0, %1 : i32 + %2 = spv.BitwiseXor %0, %1 : vector<4xi32> + ``` + }]; +} + +// ----- + +def SPV_ShiftLeftLogicalOp : SPV_ShiftOp<"ShiftLeftLogical", []> { + let summary = [{ + Shift the bits in Base left by the number of bits specified in Shift. + The least-significant bits will be zero filled. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + The type of each Base and Shift must be a scalar or vector of integer + type. Base and Shift must have the same number of components. The + number of components and bit width of the type of Base must be the same + as in Result Type. + + Shift is treated as unsigned. The result is undefined if Shift is + greater than or equal to the bit width of the components of Base. + + The number of components and bit width of Result Type must match those + Base type. All types must be integer types. + + Results are computed per component. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + shift-left-logical-op ::= ssa-id `=` `spv.ShiftLeftLogical` + ssa-use `,` ssa-use `:` + integer-scalar-vector-type `,` + integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.ShiftLeftLogical %0, %1 : i32, i16 + %5 = spv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_ShiftRightArithmeticOp : SPV_ShiftOp<"ShiftRightArithmetic", []> { + let summary = [{ + Shift the bits in Base right by the number of bits specified in Shift. + The most-significant bits will be filled with the sign bit from Base. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. 
+ + The type of each Base and Shift must be a scalar or vector of integer + type. Base and Shift must have the same number of components. The + number of components and bit width of the type of Base must be the same + as in Result Type. + + Shift is treated as unsigned. The result is undefined if Shift is + greater than or equal to the bit width of the components of Base. + + Results are computed per component. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + shift-right-arithmetic-op ::= ssa-id `=` `spv.ShiftRightArithmetic` + ssa-use `,` ssa-use `:` + integer-scalar-vector-type `,` + integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.ShiftRightArithmetic %0, %1 : i32, i16 + %5 = spv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_ShiftRightLogicalOp : SPV_ShiftOp<"ShiftRightLogical", []> { + let summary = [{ + Shift the bits in Base right by the number of bits specified in Shift. + The most-significant bits will be zero filled. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + The type of each Base and Shift must be a scalar or vector of integer + type. Base and Shift must have the same number of components. The + number of components and bit width of the type of Base must be the same + as in Result Type. + + Shift is consumed as an unsigned integer. The result is undefined if + Shift is greater than or equal to the bit width of the components of + Base. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + shift-right-logical-op ::= ssa-id `=` `spv.ShiftRightLogical` + ssa-use `,` ssa-use `:` + integer-scalar-vector-type `,` + integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.ShiftRightLogical %0, %1 : i32, i16 + %5 = spv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_NotOp : SPV_BitUnaryOp<"Not", []> { + let summary = "Complement the bits of Operand."; + + let description = [{ + Results are computed per component, and within each component, per bit. + + Result Type must be a scalar or vector of integer type. + + Operand’s type must be a scalar or vector of integer type. It must + have the same number of components as Result Type. The component width + must equal the component width in Result Type. + + ### Custom assembly form + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + not-op ::= ssa-id `=` `spv.BitNot` ssa-use `:` integer-scalar-vector-type + ``` + + For example: + + ``` + %2 = spv.Not %0 : i32 + %3 = spv.Not %1 : vector<4xi32> + ``` + }]; +} + +#endif // SPIRV_BIT_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td new file mode 100644 index 0000000000000000000000000000000000000000..99fe0bbbf5f34fb4fa285a8a6c62658947e6a939 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td @@ -0,0 +1,325 @@ +//===-- SPIRVCastOps.td - MLIR SPIR-V Cast Ops -------*- tablegen -*-------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains cast ops for the SPIR-V dialect. 
It corresponds +// to "3.32.11. Conversion Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_CAST_OPS +#define SPIRV_CAST_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +class SPV_CastOp traits = []> : + SPV_Op { + let arguments = (ins + SPV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; + let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + let verifier = [{ return verifyCastOp(this->getOperation()); }]; +} + +// ----- + +def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> { + let summary = "Bit pattern-preserving type conversion."; + + let description = [{ + Result Type must be an OpTypePointer, or a scalar or vector of + numerical-type. + + Operand must have a type of OpTypePointer, or a scalar or vector of + numerical-type. It must be a different type than Result Type. + + If either Result Type or Operand is a pointer, the other must be a + pointer (diverges from the SPIR-V spec). + + If Result Type has a different number of components than Operand, the + total number of bits in Result Type must equal the total number of bits + in Operand. Let L be the type, either Result Type or Operand’s type, + that has the larger number of components. Let S be the other type, with + the smaller number of components. The number of components in L must be + an integer multiple of the number of components in S. The first + component (that is, the only or lowest-numbered component) of S maps to + the first components of L, and so on, up to the last component of S + mapping to the last components of L. Within this mapping, any single + component of S (mapping to multiple components of L) maps its lower- + ordered bits to the lower-numbered components of L. 
+ + ### Custom assembly form + + ``` + bitcast-op ::= ssa-id `=` `spv.Bitcast` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.Bitcast %0 : f32 to i32 + %1 = spv.Bitcast %0 : vector<2xf32> to i64 + %1 = spv.Bitcast %0 : !spv.ptr to !spv.ptr + ``` + }]; + + let arguments = (ins + SPV_ScalarOrVectorOrPtr:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOrPtr:$result + ); + + let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; + let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + + let hasCanonicalizer = 1; +} + +// ----- + +def SPV_ConvertFToSOp : SPV_CastOp<"ConvertFToS", SPV_Integer, SPV_Float, []> { + let summary = [{ + Convert value numerically from floating point to signed integer, with + round toward 0.0. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + Float Value must be a scalar or vector of floating-point type. It must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + convert-f-to-s-op ::= ssa-id `=` `spv.ConvertFToSOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertFToS %0 : f32 to i32 + %3 = spv.ConvertFToS %2 : vector<3xf32> to vector<3xi32> + ``` + }]; +} + +// ----- + +def SPV_ConvertFToUOp : SPV_CastOp<"ConvertFToU", SPV_Integer, SPV_Float, []> { + let summary = [{ + Convert value numerically from floating point to unsigned integer, with + round toward 0.0. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type, whose Signedness + operand is 0. + + Float Value must be a scalar or vector of floating-point type. It must + have the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + convert-f-to-u-op ::= ssa-id `=` `spv.ConvertFToUOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertFToU %0 : f32 to i32 + %3 = spv.ConvertFToU %2 : vector<3xf32> to vector<3xi32> + ``` + }]; +} + +// ----- + +def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", SPV_Float, SPV_Integer, []> { + let summary = [{ + Convert value numerically from signed integer to floating point. + }]; + + let description = [{ + Result Type must be a scalar or vector of floating-point type. + + Signed Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + convert-s-to-f-op ::= ssa-id `=` `spv.ConvertSToFOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertSToF %0 : i32 to f32 + %3 = spv.ConvertSToF %2 : vector<3xi32> to vector<3xf32> + ``` + }]; +} + +// ----- + +def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", SPV_Float, SPV_Integer, []> { + let summary = [{ + Convert value numerically from unsigned integer to floating point. + }]; + + let description = [{ + Result Type must be a scalar or vector of floating-point type. + + Unsigned Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + convert-u-to-f-op ::= ssa-id `=` `spv.ConvertUToFOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertUToF %0 : i32 to f32 + %3 = spv.ConvertUToF %2 : vector<3xi32> to vector<3xf32> + ``` + }]; +} + +// ----- + +def SPV_FConvertOp : SPV_CastOp<"FConvert", SPV_Float, SPV_Float, []> { + let summary = [{ + Convert value numerically from one floating-point width to another + width. 
+ }]; + + let description = [{ + Result Type must be a scalar or vector of floating-point type. + + Float Value must be a scalar or vector of floating-point type. It must + have the same number of components as Result Type. The component width + cannot equal the component width in Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + f-convert-op ::= ssa-id `=` `spv.FConvertOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.FConvertOp %0 : f32 to f64 + %3 = spv.FConvertOp %2 : vector<3xf32> to vector<3xf64> + ``` + }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; +} + +// ----- + +def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, []> { + let summary = [{ + Convert signed width. This is either a truncate or a sign extend. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + Signed Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. The component width + cannot equal the component width in Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + s-convert-op ::= ssa-id `=` `spv.SConvertOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.SConvertOp %0 : i32 to i64 + %3 = spv.SConvertOp %2 : vector<3xi32> to vector<3xi64> + ``` + }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; +} + +// ----- + +def SPV_UConvertOp : SPV_CastOp<"UConvert", SPV_Integer, SPV_Integer, []> { + let summary = [{ + Convert unsigned width. This is either a truncate or a zero extend. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type, whose Signedness + operand is 0. + + Unsigned Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. 
The component width + cannot equal the component width in Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + u-convert-op ::= ssa-id `=` `spv.UConvertOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.UConvertOp %0 : i32 to i64 + %3 = spv.UConvertOp %2 : vector<3xi32> to vector<3xi64> + ``` + }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; +} + +#endif // SPIRV_CAST_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td new file mode 100644 index 0000000000000000000000000000000000000000..5a8235fff1a3e8178caf92d891a6c75d2664efb4 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -0,0 +1,166 @@ +//===-- SPIRVCompositeOps.td - MLIR SPIR-V Composite Ops ---*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains composite ops for SPIR-V dialect. It corresponds +// to "3.32.12. Composite Instructions" of the SPIR-V spec. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_COMPOSITE_OPS +#define SPIRV_COMPOSITE_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +// ----- + +def SPV_CompositeConstructOp : SPV_Op<"CompositeConstruct", [NoSideEffect]> { + let summary = [{ + Construct a new composite object from a set of constituent objects that + will fully form it. + }]; + + let description = [{ + Result Type must be a composite type, whose top-level + members/elements/components/columns have the same type as the types of + the operands, with one exception. 
The exception is that for constructing + a vector, the operands may also be vectors with the same component type + as the Result Type component type. When constructing a vector, the total + number of components in all the operands must equal the number of + components in Result Type. + + Constituents will become members of a structure, or elements of an + array, or components of a vector, or columns of a matrix. There must be + exactly one Constituent for each top-level + member/element/component/column of the result, with one exception. The + exception is that for constructing a vector, a contiguous subset of the + scalars consumed can be represented by a vector operand instead. The + Constituents must appear in the order needed by the definition of the + type of the result. When constructing a vector, there must be at least + two Constituent operands. + + ### Custom assembly form + + ``` + composite-construct-op ::= ssa-id `=` `spv.CompositeConstruct` + (ssa-use (`,` ssa-use)* )? `:` composite-type + ``` + + For example: + + ``` + %0 = spv.CompositeConstruct %1, %2, %3 : vector<3xf32> + ``` + }]; + + let arguments = (ins + Variadic:$constituents + ); + + let results = (outs + SPV_Composite:$result + ); +} + +// ----- + +def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { + let summary = "Extract a part of a composite object."; + + let description = [{ + Result Type must be the type of object selected by the last provided + index. The instruction result is the extracted object. + + Composite is the composite to extract from. + + Indexes walk the type hierarchy, potentially down to component + granularity, to select the part to extract. All indexes must be in + bounds. All composite constituents use zero-based numbering, as + described by their OpType… instruction. 
+ + ### Custom assembly form + + ``` + composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use + `[` integer-literal (',' integer-literal)* `]` + `:` composite-type + ``` + + For example: + + ``` + %0 = spv.Variable : !spv.ptr>, Function> + %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> + %2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>> + ``` + + }]; + + let arguments = (ins + SPV_Composite:$composite, + I32ArrayAttr:$indices + ); + + let results = (outs + SPV_Type:$component + ); + + let builders = [ + OpBuilder<[{Builder *builder, OperationState &state, + Value composite, ArrayRef indices}]> + ]; + + let hasFolder = 1; +} + +// ----- + +def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> { + let summary = [{ + Make a copy of a composite object, while modifying one part of it. + }]; + + let description = [{ + Result Type must be the same type as Composite. + + Object is the object to use as the modified part. + + Composite is the composite to copy all but the modified part from. + + Indexes walk the type hierarchy of Composite to the desired depth, + potentially down to component granularity, to select the part to modify. + All indexes must be in bounds. All composite constituents use zero-based + numbering, as described by their OpType… instruction. The type of the + part selected to modify must match the type of Object. 
+ + ### Custom assembly form + + ``` + composite-insert-op ::= ssa-id `=` `spv.CompositeInsert` ssa-use, ssa-use + `[` integer-literal (',' integer-literal)* `]` + `:` object-type `into` composite-type + ``` + + For example: + + ``` + %0 = spv.CompositeInsert %object, %composite[1 : i32] : f32 into !spv.array<4xf32> + ``` + }]; + + let arguments = (ins + SPV_Type:$object, + SPV_Composite:$composite, + I32ArrayAttr:$indices + ); + + let results = (outs + SPV_Composite:$result + ); +} + +#endif // SPIRV_COMPOSITE_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td new file mode 100644 index 0000000000000000000000000000000000000000..be0955794515ed58a8e245d7ab236f0694be26da --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -0,0 +1,466 @@ +//===-- SPIRVControlFlowOps.td - SPIR-V Control Flow Ops ---*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains control flow ops for the SPIR-V dialect. It corresponds +// to "3.32.17. Control-Flow Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_CONTROLFLOW_OPS +#define SPIRV_CONTROLFLOW_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" +include "mlir/Analysis/CallInterfaces.td" + +// ----- + +def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> { + let summary = "Unconditional branch to target block."; + + let description = [{ + This instruction must be the last instruction in a block. + + ### Custom assembly form + + ``` + branch-op ::= `spv.Branch` successor + successor ::= bb-id branch-use-list? 
+ branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` + ``` + + For example: + + ``` + spv.Branch ^target + spv.Branch ^target(%0, %1: i32, f32) + ``` + }]; + + let arguments = (ins + Variadic:$block_arguments + ); + + let results = (outs); + + let builders = [ + OpBuilder< + "Builder *, OperationState &state, " + "Block *successor, ValueRange arguments = {}", [{ + state.addSuccessor(successor, arguments); + }] + > + ]; + + let skipDefaultBuilders = 1; + + let extraClassDeclaration = [{ + /// Returns the branch target block. + Block *getTarget() { return getOperation()->getSuccessor(0); } + + /// Returns the block arguments. + operand_range getBlockArguments() { + return getOperation()->getSuccessorOperands(0); + } + }]; + + let autogenSerialization = 0; +} + +// ----- + +def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", + [InFunctionScope, Terminator]> { + let summary = [{ + If Condition is true, branch to true block, otherwise branch to false + block. + }]; + + let description = [{ + Condition must be a Boolean type scalar. + + Branch weights are unsigned 32-bit integer literals. There must be + either no Branch Weights or exactly two branch weights. If present, the + first is the weight for branching to True Label, and the second is the + weight for branching to False Label. The implied probability that a + branch is taken is its weight divided by the sum of the two Branch + weights. At least one weight must be non-zero. A weight of zero does not + imply a branch is dead or permit its removal; branch weights are only + hints. The two weights must not overflow a 32-bit unsigned integer when + added together. + + This instruction must be the last instruction in a block. + + ### Custom assembly form + + ``` + branch-conditional-op ::= `spv.BranchConditional` ssa-use + (`[` integer-literal, integer-literal `]`)? + `,` successor `,` successor + successor ::= bb-id branch-use-list? 
+ branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` + ``` + + For example: + + ``` + spv.BranchConditional %condition, ^true_branch, ^false_branch + spv.BranchConditional %condition, ^true_branch(%0: i32), ^false_branch(%1: i32) + ``` + }]; + + let arguments = (ins + SPV_Bool:$condition, + Variadic:$branch_arguments, + OptionalAttr:$branch_weights + ); + + let results = (outs); + + let builders = [ + OpBuilder< + "Builder *builder, OperationState &state, Value condition, " + "Block *trueBlock, ValueRange trueArguments, " + "Block *falseBlock, ValueRange falseArguments, " + "Optional> weights = {}", + [{ + state.addOperands(condition); + state.addSuccessor(trueBlock, trueArguments); + state.addSuccessor(falseBlock, falseArguments); + if (weights) { + auto attr = + builder->getI32ArrayAttr({static_cast(weights->first), + static_cast(weights->second)}); + state.addAttribute("branch_weights", attr); + } + }] + > + ]; + + let skipDefaultBuilders = 1; + + let autogenSerialization = 0; + + let extraClassDeclaration = [{ + /// Branch indices into the successor list. + enum { kTrueIndex = 0, kFalseIndex = 1 }; + + /// Returns the target block for the true branch. + Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); } + + /// Returns the target block for the false branch. + Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); } + + /// Returns the number of arguments to the true target block. + unsigned getNumTrueBlockArguments() { + return getNumSuccessorOperands(kTrueIndex); + } + + /// Returns the number of arguments to the false target block. + unsigned getNumFalseBlockArguments() { + return getNumSuccessorOperands(kFalseIndex); + } + + // Iterator and range support for true target block arguments. 
+ operand_iterator true_block_argument_begin() { + return operand_begin() + getTrueBlockArgumentIndex(); + } + operand_iterator true_block_argument_end() { + return true_block_argument_begin() + getNumTrueBlockArguments(); + } + operand_range getTrueBlockArguments() { + return {true_block_argument_begin(), true_block_argument_end()}; + } + + // Iterator and range support for false target block arguments. + operand_iterator false_block_argument_begin() { + return true_block_argument_end(); + } + operand_iterator false_block_argument_end() { + return false_block_argument_begin() + getNumFalseBlockArguments(); + } + operand_range getFalseBlockArguments() { + return {false_block_argument_begin(), false_block_argument_end()}; + } + + private: + /// Gets the index of the first true block argument in the operand list. + unsigned getTrueBlockArgumentIndex() { + return 1; // Omit the first argument, which is the condition. + } + + /// Gets the index of the first false block argument in the operand list. + unsigned getFalseBlockArgumentIndex() { + return getTrueBlockArgumentIndex() + getNumTrueBlockArguments(); + } + }]; +} + +// ----- + +def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [ + InFunctionScope, DeclareOpInterfaceMethods]> { + let summary = "Call a function."; + + let description = [{ + Result Type is the type of the return value of the function. It must be + the same as the Return Type operand of the Function Type operand of the + Function operand. + + Function is an OpFunction instruction. This could be a forward + reference. + + Argument N is the object to copy to parameter N of Function. + + Note: A forward call is possible because there is no missing type + information: Result Type must match the Return Type of the function, and + the calling argument types must match the formal parameter types. 
+ + ### Custom assembly form + + ``` + function-call-op ::= `spv.FunctionCall` function-id `(` ssa-use-list `)` + `:` function-type + ``` + + For example: + + ``` + spv.FunctionCall @f_void(%arg0) : (i32) -> () + %0 = spv.FunctionCall @f_iadd(%arg0, %arg1) : (i32, i32) -> i32 + ``` + }]; + + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$arguments + ); + + let results = (outs + SPV_Optional:$result + ); + + let autogenSerialization = 0; +} + +// ----- + +def SPV_LoopOp : SPV_Op<"loop", [InFunctionScope]> { + let summary = "Define a structured loop."; + + let description = [{ + SPIR-V can explicitly declare structured control-flow constructs using merge + instructions. These explicitly declare a header block before the control + flow diverges and a merge block where control flow subsequently converges. + These blocks delimit constructs that must nest, and can only be entered + and exited in structured ways. See "2.11. Structured Control Flow" of the + SPIR-V spec for more details. + + Instead of having a `spv.LoopMerge` op to directly model loop merge + instruction for indicating the merge and continue target, we use regions + to delimit the boundary of the loop: the merge target is the next op + following the `spv.loop` op and the continue target is the block that + has a back-edge pointing to the entry block inside the `spv.loop`'s region. + This way it's easier to discover all blocks belonging to a construct and + it plays nicer with the MLIR system. + + The `spv.loop` region should contain at least four blocks: one entry block, + one loop header block, one loop continue block, one loop merge block. + The entry block should be the first block and it should jump to the loop + header block, which is the second block. The loop merge block should be the + last block. The merge block should only contain a `spv._merge` op. + The continue block should be the second to last block and it should have a + branch to the loop header block. 
The loop continue block should be the only + block, except the entry block, branching to the header block. + }]; + + let arguments = (ins + SPV_LoopControlAttr:$loop_control + ); + + let results = (outs); + + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<"Builder *builder, OperationState &state">]; + + let extraClassDeclaration = [{ + // Returns the entry block. + Block *getEntryBlock(); + + // Returns the loop header block. + Block *getHeaderBlock(); + + // Returns the loop continue block. + Block *getContinueBlock(); + + // Returns the loop merge block. + Block *getMergeBlock(); + + // Adds an empty entry block and loop merge block containing one + // spv._merge op. + void addEntryAndMergeBlock(); + }]; + + let hasOpcode = 0; + + let autogenSerialization = 0; +} + +// ----- + +def SPV_MergeOp : SPV_Op<"_merge", [Terminator]> { + let summary = "A special terminator for merging a structured selection/loop."; + + let description = [{ + We use `spv.selection`/`spv.loop` for modelling structured selection/loop. + This op is a terminator used inside their regions to mean jumping to the + merge point, which is the next op following the `spv.selection` or + `spv.loop` op. This op does not have a corresponding instruction in the + SPIR-V binary format; it's solely for structural purpose. + }]; + + let arguments = (ins); + + let results = (outs); + + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; + + let hasOpcode = 0; + + let autogenSerialization = 0; +} + +// ----- + +def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> { + let summary = "Return with no value from a function with void return type."; + + let description = [{ + This instruction must be the last instruction in a block. 
+ + ### Custom assembly form + + ``` + return-op ::= `spv.Return` + ``` + }]; + + let arguments = (ins); + + let results = (outs); + + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; +} + +// ----- + +def SPV_UnreachableOp : SPV_Op<"Unreachable", [InFunctionScope, Terminator]> { + let summary = "Declares that this block is not reachable in the CFG."; + + let description = [{ + This instruction must be the last instruction in a block. + + ### Custom assembly form + + ``` + unreachable-op ::= `spv.Unreachable` + ``` + }]; + + let arguments = (ins); + + let results = (outs); + + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; +} + +// ----- + +def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> { + let summary = "Return a value from a function."; + + let description = [{ + Value is the value returned, by copy, and must match the Return Type + operand of the OpTypeFunction type of the OpFunction body this return + instruction is in. + + This instruction must be the last instruction in a block. + + ### Custom assembly form + + ``` + return-value-op ::= `spv.ReturnValue` ssa-use `:` spirv-type + ``` + + For example: + + ``` + spv.ReturnValue %0 : f32 + ``` + }]; + + let arguments = (ins + SPV_Type:$value + ); + + let results = (outs); +} + +def SPV_SelectionOp : SPV_Op<"selection", [InFunctionScope]> { + let summary = "Define a structured selection."; + + let description = [{ + SPIR-V can explicitly declare structured control-flow constructs using merge + instructions. These explicitly declare a header block before the control + flow diverges and a merge block where control flow subsequently converges. + These blocks delimit constructs that must nest, and can only be entered + and exited in structured ways. See "2.11. Structured Control Flow" of the + SPIR-V spec for more details. 
+ + Instead of having a `spv.SelectionMerge` op to directly model selection + merge instruction for indicating the merge target, we use regions to delimit + the boundary of the selection: the merge target is the next op following the + `spv.selection` op. This way it's easier to discover all blocks belonging to + the selection and it plays nicer with the MLIR system. + + The `spv.selection` region should contain at least two blocks: one selection + header block, and one selection merge. The selection header block should be + the first block. The selection merge block should be the last block. + The merge block should only contain a `spv._merge` op. + }]; + + let arguments = (ins + SPV_SelectionControlAttr:$selection_control + ); + + let results = (outs); + + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + // Returns the selection header block. + Block *getHeaderBlock(); + + // Returns the selection merge block. + Block *getMergeBlock(); + + // Adds a selection merge block containing one spv._merge op. + void addMergeBlock(); + }]; + + let hasOpcode = 0; + + let autogenSerialization = 0; + + let hasCanonicalizer = 1; +} + +#endif // SPIRV_CONTROLFLOW_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h new file mode 100644 index 0000000000000000000000000000000000000000..0c0eebd34d1640899a8eb9bab1f2da22ea447408 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h @@ -0,0 +1,53 @@ +//===- SPIRVDialect.h - MLIR SPIR-V dialect ---------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the SPIR-V dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_ +#define MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_ + +#include "mlir/IR/Dialect.h" + +namespace mlir { +namespace spirv { + +enum class Decoration : uint32_t; + +class SPIRVDialect : public Dialect { +public: + explicit SPIRVDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "spv"; } + + /// Checks if the given `type` is valid in SPIR-V dialect. + static bool isValidType(Type type); + + /// Checks if the given `scalar type` is valid in SPIR-V dialect. + static bool isValidScalarType(Type type); + + /// Returns the attribute name to use when specifying decorations on results + /// of operations. + static std::string getAttributeName(Decoration decoration); + + /// Parses a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + /// Prints a type registered to this dialect. + void printType(Type type, DialectAsmPrinter &os) const override; + + /// Provides a hook for materializing a constant to this dialect. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; +}; + +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_DIALECT_SPIRV_SPIRVDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td new file mode 100644 index 0000000000000000000000000000000000000000..b2eacbf306aea2309053b202773161bc1e33e750 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td @@ -0,0 +1,570 @@ +//===- SPIRVGLSLOps.td - GLSL extended insts spec file -----*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of GLSL extension ops. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_GLSL_OPS +#define SPIRV_GLSL_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +//===----------------------------------------------------------------------===// +// SPIR-V GLSL 4.50 opcode specification. +//===----------------------------------------------------------------------===// + +// Base class for all GLSL ops. +class SPV_GLSLOp traits = []> : + SPV_ExtInstOp; + +// Base class for GLSL unary ops. +class SPV_GLSLUnaryOp traits = []> : + SPV_GLSLOp { + + let arguments = (ins + SPV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return parseUnaryOp(parser, result); }]; + + let printer = [{ return printUnaryOp(getOperation(), p); }]; + + let verifier = [{ return success(); }]; +} + +// Base class for GLSL Unary arithmetic ops where return type matches +// the operand type. +class SPV_GLSLUnaryArithmeticOp traits = []> : + SPV_GLSLUnaryOp; + +// Base class for GLSL binary ops. +class SPV_GLSLBinaryOp traits = []> : + SPV_GLSLOp { + + let arguments = (ins + SPV_ScalarOrVectorOf:$lhs, + SPV_ScalarOrVectorOf:$rhs + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + + let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; + + let verifier = [{ return success(); }]; +} + +// Base class for GLSL Binary arithmetic ops where operand types and +// return type matches. 
+class SPV_GLSLBinaryArithmeticOp traits = []> : + SPV_GLSLBinaryOp; + +// ----- + +def SPV_GLSLFAbsOp : SPV_GLSLUnaryArithmeticOp<"FAbs", 4, SPV_Float> { + let summary = "Absolute value of operand"; + + let description = [{ + Result is x if x >= 0; otherwise result is -x. + + The operand x must be a scalar or vector whose component type is + floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. + + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + abs-op ::= ssa-id `=` `spv.GLSL.FAbs` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.FAbs %0 : f32 + %3 = spv.GLSL.FAbs %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLSAbsOp : SPV_GLSLUnaryArithmeticOp<"SAbs", 5, SPV_Integer> { + let summary = "Absolute value of operand"; + + let description = [{ + Result is x if x ≥ 0; otherwise result is -x, where x is interpreted as a + signed integer. + + Result Type and the type of x must both be integer scalar or integer vector + types. Result Type and operand types must have the same number of components + with the same component width. Results are computed per component. + + ### Custom assembly format + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + abs-op ::= ssa-id `=` `spv.GLSL.SAbs` ssa-use `:` + integer-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.SAbs %0 : i32 + %3 = spv.GLSL.SAbs %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_GLSLCeilOp : SPV_GLSLUnaryArithmeticOp<"Ceil", 9, SPV_Float> { + let summary = "Rounds up to the next whole number"; + + let description = [{ + Result is the value equal to the nearest whole number that is greater than + or equal to x. + + The operand x must be a scalar or vector whose component type is + floating-point. + + Result Type and the type of x must be the same type. 
Results are computed + per component. + + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + ceil-op ::= ssa-id `=` `spv.GLSL.Ceil` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.Ceil %0 : f32 + %3 = spv.GLSL.Ceil %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLCosOp : SPV_GLSLUnaryArithmeticOp<"Cos", 14, SPV_Float16or32> { + let summary = "Cosine of operand in radians"; + + let description = [{ + The standard trigonometric cosine of x radians. + + The operand x must be a scalar or vector whose component type is 16-bit or + 32-bit floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. + + ### Custom assembly format + ``` + restricted-float-scalar-type ::= `f16` | `f32` + restricted-float-scalar-vector-type ::= + restricted-float-scalar-type | + `vector<` integer-literal `x` restricted-float-scalar-type `>` + cos-op ::= ssa-id `=` `spv.GLSL.Cos` ssa-use `:` + restricted-float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.Cos %0 : f32 + %3 = spv.GLSL.Cos %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLExpOp : SPV_GLSLUnaryArithmeticOp<"Exp", 27, SPV_Float16or32> { + let summary = "Exponentiation of Operand 1"; + + let description = [{ + Result is the natural exponentiation of x; e^x. + + The operand x must be a scalar or vector whose component type is + 16-bit or 32-bit floating-point. + + Result Type and the type of x must be the same type. 
Results are + computed per component."; + + ### Custom assembly format + ``` + restricted-float-scalar-type ::= `f16` | `f32` + restricted-float-scalar-vector-type ::= + restricted-float-scalar-type | + `vector<` integer-literal `x` restricted-float-scalar-type `>` + exp-op ::= ssa-id `=` `spv.GLSL.Exp` ssa-use `:` + restricted-float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.Exp %0 : f32 + %3 = spv.GLSL.Exp %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLFloorOp : SPV_GLSLUnaryArithmeticOp<"Floor", 8, SPV_Float> { + let summary = "Rounds down to the next whole number"; + + let description = [{ + Result is the value equal to the nearest whole number that is less than or + equal to x. + + The operand x must be a scalar or vector whose component type is + floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. + + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + floor-op ::= ssa-id `=` `spv.GLSL.Floor` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.Floor %0 : f32 + %3 = spv.GLSL.Floor %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLInverseSqrtOp : SPV_GLSLUnaryArithmeticOp<"InverseSqrt", 32, SPV_Float> { + let summary = "Reciprocal of sqrt(operand)"; + + let description = [{ + Result is the reciprocal of sqrt x. Result is undefined if x <= 0. + + The operand x must be a scalar or vector whose component type is + floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. 
+ + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + rsqrt-op ::= ssa-id `=` `spv.GLSL.InverseSqrt` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.InverseSqrt %0 : f32 + %3 = spv.GLSL.InverseSqrt %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLLogOp : SPV_GLSLUnaryArithmeticOp<"Log", 28, SPV_Float16or32> { + let summary = "Natural logarithm of the operand"; + + let description = [{ + Result is the natural logarithm of x, i.e., the value y which satisfies the + equation x = ey. Result is undefined if x <= 0. + + The operand x must be a scalar or vector whose component type is 16-bit or + 32-bit floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. + + ### Custom assembly format + ``` + restricted-float-scalar-type ::= `f16` | `f32` + restricted-float-scalar-vector-type ::= + restricted-float-scalar-type | + `vector<` integer-literal `x` restricted-float-scalar-type `>` + log-op ::= ssa-id `=` `spv.GLSL.Log` ssa-use `:` + restricted-float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.Log %0 : f32 + %3 = spv.GLSL.Log %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLFMaxOp : SPV_GLSLBinaryArithmeticOp<"FMax", 40, SPV_Float> { + let summary = "Return maximum of two floating-point operands"; + + let description = [{ + Result is y if x < y; otherwise result is x. Which operand is the + result is undefined if one of the operands is a NaN. + + The operands must all be a scalar or vector whose component type + is floating-point. + + Result Type and the type of all operands must be the same + type. Results are computed per component. 
+ + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fmax-op ::= ssa-id `=` `spv.GLSL.FMax` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.FMax %0, %1 : f32 + %3 = spv.GLSL.FMax %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLSMaxOp : SPV_GLSLBinaryArithmeticOp<"SMax", 42, SPV_Integer> { + let summary = "Return maximum of two signed integer operands"; + + let description = [{ + Result is y if x < y; otherwise result is x, where x and y are interpreted + as signed integers. + + Result Type and the type of x and y must both be integer scalar or integer + vector types. Result Type and operand types must have the same number of + components with the same component width. Results are computed per + component. + + ### Custom assembly format + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + smax-op ::= ssa-id `=` `spv.GLSL.SMax` ssa-use `:` + integer-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.SMax %0, %1 : i32 + %3 = spv.GLSL.SMax %0, %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_GLSLFMinOp : SPV_GLSLBinaryArithmeticOp<"FMin", 37, SPV_Float> { + let summary = "Return minimum of two floating-point operands"; + + let description = [{ + Result is y if y < x; otherwise result is x. Which operand is the result is + undefined if one of the operands is a NaN. + + The operands must all be a scalar or vector whose component type is + floating-point. + + Result Type and the type of all operands must be the same type. Results are + computed per component. 
+ + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fmin-op ::= ssa-id `=` `spv.GLSL.FMin` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.FMin %0, %1 : f32 + %3 = spv.GLSL.FMin %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLSMinOp : SPV_GLSLBinaryArithmeticOp<"SMin", 39, SPV_Integer> { + let summary = "Return minimum of two signed integer operands"; + + let description = [{ + Result is y if y < x; otherwise result is x, where x and y are interpreted + as signed integers. + + Result Type and the type of x and y must both be integer scalar or integer + vector types. Result Type and operand types must have the same number of + components with the same component width. Results are computed per + component. + + ### Custom assembly format + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + smin-op ::= ssa-id `=` `spv.GLSL.SMin` ssa-use `:` + integer-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.SMin %0, %1 : i32 + %3 = spv.GLSL.SMin %0, %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_GLSLFSignOp : SPV_GLSLUnaryArithmeticOp<"FSign", 6, SPV_Float> { + let summary = "Returns the sign of the operand"; + + let description = [{ + Result is 1.0 if x > 0, 0.0 if x = 0, or -1.0 if x < 0. + + The operand x must be a scalar or vector whose component type is + floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. 
+ + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + sign-op ::= ssa-id `=` `spv.GLSL.FSign` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.FSign %0 : f32 + %3 = spv.GLSL.FSign %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLSSignOp : SPV_GLSLUnaryArithmeticOp<"SSign", 7, SPV_Integer> { + let summary = "Returns the sign of the operand"; + + let description = [{ + Result is 1 if x > 0, 0 if x = 0, or -1 if x < 0, where x is interpreted as + a signed integer. + + Result Type and the type of x must both be integer scalar or integer vector + types. Result Type and operand types must have the same number of components + with the same component width. Results are computed per component. + + ### Custom assembly format + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + sign-op ::= ssa-id `=` `spv.GLSL.SSign` ssa-use `:` + integer-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.SSign %0 : i32 + %3 = spv.GLSL.SSign %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_GLSLSqrtOp : SPV_GLSLUnaryArithmeticOp<"Sqrt", 31, SPV_Float> { + let summary = "Returns the square root of the operand"; + + let description = [{ + Result is the square root of x. Result is undefined if x < 0. + + The operand x must be a scalar or vector whose component type is + floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. 
+ + ### Custom assembly format + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + sqrt-op ::= ssa-id `=` `spv.GLSL.Sqrt` ssa-use `:` + float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.Sqrt %0 : f32 + %3 = spv.GLSL.Sqrt %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_GLSLTanhOp : SPV_GLSLUnaryArithmeticOp<"Tanh", 21, SPV_Float16or32> { + let summary = "Hyperbolic tangent of operand in radians"; + + let description = [{ + Hyperbolic tangent of x radians. + + The operand x must be a scalar or vector whose component type is 16-bit or + 32-bit floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. + + ### Custom assembly format + ``` + restricted-float-scalar-type ::= `f16` | `f32` + restricted-float-scalar-vector-type ::= + restricted-float-scalar-type | + `vector<` integer-literal `x` restricted-float-scalar-type `>` + tanh-op ::= ssa-id `=` `spv.GLSL.Tanh` ssa-use `:` + restricted-float-scalar-vector-type + ``` + For example: + + ``` + %2 = spv.GLSL.Tanh %0 : f32 + %3 = spv.GLSL.Tanh %1 : vector<3xf16> + ``` + }]; +} + +#endif // SPIRV_GLSL_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td new file mode 100644 index 0000000000000000000000000000000000000000..827636afbafa69c3590819f18f35ae8f83f7689f --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td @@ -0,0 +1,65 @@ +//===-- SPIRVGroupOps.td - MLIR SPIR-V (Sub)Group Ops ------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains group and subgroup ops for the SPIR-V dialect. It +// corresponds to "3.32.21. 
Group and Subgroup Instructions" of the SPIR-V +// specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_GROUP_OPS +#define SPIRV_GROUP_OPS + +// ----- + +def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> { + let summary = "See extension SPV_KHR_shader_ballot"; + + let description = [{ + Computes a bitfield value combining the Predicate value from all invocations + in the current Subgroup that execute the same dynamic instance of this + instruction. The bit is set to one if the corresponding invocation is active + and the predicate is evaluated to true; otherwise, it is set to zero. + + Predicate must be a Boolean type. + + Result Type must be a 4 component vector of 32 bit integer types. + + Result is a set of bitfields where the first invocation is represented in bit + 0 of the first vector component and the last (up to SubgroupSize) is the + higher bit number of the last bitmask needed to represent all bits of the + subgroup invocations. + + ### Custom assembly form + + ``` + subgroup-ballot-op ::= ssa-id `=` `spv.SubgroupBallotKHR` + ssa-use `:` `vector` `<` 4 `x` `i32` `>` + ``` + + For example: + + ``` + %0 = spv.SubgroupBallotKHR %predicate : vector<4xi32> + ``` + }]; + + let arguments = (ins + SPV_Bool:$predicate + ); + + let results = (outs + SPV_I32Vec4:$result + ); + + let verifier = [{ return success(); }]; +} + +// ----- + +#endif // SPIRV_GROUP_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td new file mode 100644 index 0000000000000000000000000000000000000000..ac377d5e866612fbe0de74524f3b83b07e094cab --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -0,0 +1,991 @@ +//===-- SPIRVLogicalOps.td - MLIR SPIR-V Logical Ops -------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains arithmetic ops for the SPIR-V dialect. It corresponds +// to "3.32.15. Relational and Logical Instructions" of the SPIR-V spec. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_LOGICAL_OPS +#define SPIRV_LOGICAL_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +class SPV_LogicalBinaryOp traits = []> : + // Result type is SPV_Bool. + SPV_BinaryOp { + let parser = [{ return ::parseLogicalBinaryOp(parser, result); }]; + let printer = [{ return ::printLogicalOp(getOperation(), p); }]; +} + +class SPV_LogicalUnaryOp traits = []> : + // Result type is SPV_Bool. + SPV_UnaryOp { + let parser = [{ return ::parseLogicalUnaryOp(parser, result); }]; + let printer = [{ return ::printLogicalOp(getOperation(), p); }]; +} + +// ----- + +def SPV_FOrdEqualOp : SPV_LogicalBinaryOp<"FOrdEqual", SPV_Float, [Commutative]> { + let summary = "Floating-point comparison for being ordered and equal."; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fordequal-op ::= ssa-id `=` `spv.FOrdEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FOrdEqual %0, %1 : f32 + %5 = spv.FOrdEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FOrdGreaterThanOp : SPV_LogicalBinaryOp<"FOrdGreaterThan", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are ordered and Operand 1 is + greater than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fordgt-op ::= ssa-id `=` `spv.FOrdGreaterThan` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FOrdGreaterThan %0, %1 : f32 + %5 = spv.FOrdGreaterThan %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FOrdGreaterThanEqualOp : SPV_LogicalBinaryOp<"FOrdGreaterThanEqual", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are ordered and Operand 1 is + greater than or equal to Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fordgte-op ::= ssa-id `=` `spv.FOrdGreaterThanEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FOrdGreaterThanEqual %0, %1 : f32 + %5 = spv.FOrdGreaterThanEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FOrdLessThanOp : SPV_LogicalBinaryOp<"FOrdLessThan", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are ordered and Operand 1 is less + than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fordlt-op ::= ssa-id `=` `spv.FOrdLessThan` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FOrdLessThan %0, %1 : f32 + %5 = spv.FOrdLessThan %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FOrdLessThanEqualOp : SPV_LogicalBinaryOp<"FOrdLessThanEqual", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are ordered and Operand 1 is less + than or equal to Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fordlte-op ::= ssa-id `=` `spv.FOrdLessThanEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FOrdLessThanEqual %0, %1 : f32 + %5 = spv.FOrdLessThanEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FOrdNotEqualOp : SPV_LogicalBinaryOp<"FOrdNotEqual", SPV_Float, [Commutative]> { + let summary = "Floating-point comparison for being ordered and not equal."; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fordneq-op ::= ssa-id `=` `spv.FOrdNotEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FOrdNotEqual %0, %1 : f32 + %5 = spv.FOrdNotEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FUnordEqualOp : SPV_LogicalBinaryOp<"FUnordEqual", SPV_Float, [Commutative]> { + let summary = "Floating-point comparison for being unordered or equal."; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + funordequal-op ::= ssa-id `=` `spv.FUnordEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FUnordEqual %0, %1 : f32 + %5 = spv.FUnordEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FUnordGreaterThanOp : SPV_LogicalBinaryOp<"FUnordGreaterThan", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are unordered or Operand 1 is + greater than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + funordgt-op ::= ssa-id `=` `spv.FUnordGreaterThan` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FUnordGreaterThan %0, %1 : f32 + %5 = spv.FUnordGreaterThan %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FUnordGreaterThanEqualOp : SPV_LogicalBinaryOp<"FUnordGreaterThanEqual", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are unordered or Operand 1 is + greater than or equal to Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + funordgte-op ::= ssa-id `=` `spv.FUnordGreaterThanEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FUnordGreaterThanEqual %0, %1 : f32 + %5 = spv.FUnordGreaterThanEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FUnordLessThanOp : SPV_LogicalBinaryOp<"FUnordLessThan", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are unordered or Operand 1 is less + than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + funordlt-op ::= ssa-id `=` `spv.FUnordLessThan` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FUnordLessThan %0, %1 : f32 + %5 = spv.FUnordLessThan %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FUnordLessThanEqualOp : SPV_LogicalBinaryOp<"FUnordLessThanEqual", SPV_Float, []> { + let summary = [{ + Floating-point comparison if operands are unordered or Operand 1 is less + than or equal to Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + funordlte-op ::= ssa-id `=` `spv.FUnordLessThanEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FUnordLessThanEqual %0, %1 : f32 + %5 = spv.FUnordLessThanEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FUnordNotEqualOp : SPV_LogicalBinaryOp<"FUnordNotEqual", SPV_Float, [Commutative]> { + let summary = "Floating-point comparison for being unordered or not equal."; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + floating-point type. They must have the same type, and they must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + funordneq-op ::= ssa-id `=` `spv.FUnordNotEqual` ssa-use, ssa-use + ``` + + For example: + + ``` + %4 = spv.FUnordNotEqual %0, %1 : f32 + %5 = spv.FUnordNotEqual %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_IEqualOp : SPV_LogicalBinaryOp<"IEqual", SPV_Integer, [Commutative]> { + let summary = "Integer comparison for equality."; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + iequal-op ::= ssa-id `=` `spv.IEqual` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.IEqual %0, %1 : i32 + %5 = spv.IEqual %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual", SPV_Integer, [Commutative]> { + let summary = "Integer comparison for inequality."; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + inot-equal-op ::= ssa-id `=` `spv.INotEqual` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.INotEqual %0, %1 : i32 + %5 = spv.INotEqual %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if both Operand 1 and Operand 2 are true. Result is false + if either Operand 1 or Operand 2 are false. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. + + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + + ``` + logical-and ::= `spv.LogicalAnd` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalAnd %0, %1 : i1 + %2 = spv.LogicalAnd %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_LogicalEqualOp : SPV_LogicalBinaryOp<"LogicalEqual", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if Operand 1 and Operand 2 have the same value. Result is + false if Operand 1 and Operand 2 have different values. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. + + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + logical-equal ::= `spv.LogicalEqual` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalEqual %0, %1 : i1 + %2 = spv.LogicalEqual %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot", SPV_Bool, []> { + let summary = [{ + Result is true if Operand is false. Result is false if Operand is true. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + logical-not ::= `spv.LogicalNot` ssa-use `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalNot %0 : i1 + %2 = spv.LogicalNot %0 : vector<4xi1> + ``` + }]; + + let hasCanonicalizer = 1; +} + +// ----- + +def SPV_LogicalNotEqualOp : SPV_LogicalBinaryOp<"LogicalNotEqual", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if Operand 1 and Operand 2 have different values. Result + is false if Operand 1 and Operand 2 have the same value. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. 
+ + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + logical-not-equal ::= `spv.LogicalNotEqual` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalNotEqual %0, %1 : i1 + %2 = spv.LogicalNotEqual %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if either Operand 1 or Operand 2 is true. Result is false + if both Operand 1 and Operand 2 are false. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. + + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` + logical-or ::= `spv.LogicalOr` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalOr %0, %1 : i1 + %2 = spv.LogicalOr %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan", SPV_Integer, []> { + let summary = [{ + Signed-integer comparison if Operand 1 is greater than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + sgreater-than-op ::= ssa-id `=` `spv.SGreaterThan` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.SGreaterThan %0, %1 : i32 + %5 = spv.SGreaterThan %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_SGreaterThanEqualOp : SPV_LogicalBinaryOp<"SGreaterThanEqual", SPV_Integer, []> { + let summary = [{ + Signed-integer comparison if Operand 1 is greater than or equal to + Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + sgreater-than-equal-op ::= ssa-id `=` `spv.SGreaterThanEqual` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.SGreaterThanEqual %0, %1 : i32 + %5 = spv.SGreaterThanEqual %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_SLessThanOp : SPV_LogicalBinaryOp<"SLessThan", SPV_Integer, []> { + let summary = [{ + Signed-integer comparison if Operand 1 is less than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + sless-than-op ::= ssa-id `=` `spv.SLessThan` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.SLessThan %0, %1 : i32 + %5 = spv.SLessThan %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, []> { + let summary = [{ + Signed-integer comparison if Operand 1 is less than or equal to Operand + 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + sless-than-equal-op ::= ssa-id `=` `spv.SLessThanEqual` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.SLessThanEqual %0, %1 : i32 + %5 = spv.SLessThanEqual %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> { + let summary = [{ + Select between two objects. Before version 1.4, results are only + computed per component. + }]; + + let description = [{ + Before version 1.4, Result Type must be a pointer, scalar, or vector. + + The types of Object 1 and Object 2 must be the same as Result Type. + + Condition must be a scalar or vector of Boolean type. + + If Condition is a scalar and true, the result is Object 1. If Condition + is a scalar and false, the result is Object 2. 
+ + If Condition is a vector, Result Type must be a vector with the same + number of components as Condition and the result is a mix of Object 1 + and Object 2: When a component of Condition is true, the corresponding + component in the result is taken from Object 1, otherwise it is taken + from Object 2. + + ### Custom assembly form + + ``` + scalar-type ::= integer-type | float-type | boolean-type + select-object-type ::= scalar-type + | `vector<` integer-literal `x` scalar-type `>` + | pointer-type + select-condition-type ::= boolean-type + | `vector<` integer-literal `x` boolean-type `>` + select-op ::= ssa-id `=` `spv.Select` ssa-use, ssa-use, ssa-use + `:` select-condition-type `,` select-object-type + ``` + + For example: + + ``` + %3 = spv.Select %0, %1, %2 : i1, f32 + %3 = spv.Select %0, %1, %2 : i1, vector<3xi32> + %3 = spv.Select %0, %1, %2 : vector<3xi1>, vector<3xf32> + ``` + }]; + + let arguments = (ins + SPV_ScalarOrVectorOf:$condition, + SPV_SelectType:$true_value, + SPV_SelectType:$false_value + ); + + let results = (outs + SPV_SelectType:$result + ); + + let builders = [OpBuilder<[{Builder *builder, OperationState &state, + Value cond, Value trueValue, + Value falseValue}]>]; +} + +// ----- + +def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> { + let summary = [{ + Unsigned-integer comparison if Operand 1 is greater than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + ugreater-than-op ::= ssa-id `=` `spv.UGreaterThan` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.UGreaterThan %0, %1 : i32 + %5 = spv.UGreaterThan %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integer, []> { + let summary = [{ + Unsigned-integer comparison if Operand 1 is greater than or equal to + Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + ugreater-than-equal-op ::= ssa-id `=` `spv.UGreaterThanEqual` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.UGreaterThanEqual %0, %1 : i32 + %5 = spv.UGreaterThanEqual %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> { + let summary = [{ + Unsigned-integer comparison if Operand 1 is less than Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. 
+ + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + uless-than-op ::= ssa-id `=` `spv.ULessThan` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.ULessThan %0, %1 : i32 + %5 = spv.ULessThan %2, %3 : vector<4xi32> + + ``` + }]; +} + +// ----- + +def SPV_ULessThanEqualOp : + SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, []> { + let summary = [{ + Unsigned-integer comparison if Operand 1 is less than or equal to + Operand 2. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 and Operand 2 must be a scalar or vector of + integer type. They must have the same component width, and they must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + uless-than-equal-op ::= ssa-id `=` `spv.ULessThanEqual` ssa-use, ssa-use + `:` integer-scalar-vector-type + ``` + For example: + + ``` + %4 = spv.ULessThanEqual %0, %1 : i32 + %5 = spv.ULessThanEqual %2, %3 : vector<4xi32> + + ``` + }]; +} + +#endif // SPIRV_LOGICAL_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h new file mode 100644 index 0000000000000000000000000000000000000000..0f481f5956d180eba05509ee0e9d19c27dfede36 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -0,0 +1,86 @@ +//===- SPIRVLowering.h - SPIR-V lowering utilities -------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while targeting SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H +#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H + +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Support/StringExtras.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir { + +/// Type conversion from standard types to SPIR-V types for shader interface. +/// +/// For composite types, this converter additionally performs type wrapping to +/// satisfy shader interface requirements: shader interface types must be +/// pointers to structs. +class SPIRVTypeConverter final : public TypeConverter { +public: + using TypeConverter::TypeConverter; + + /// Converts the given standard `type` to SPIR-V correspondence. + Type convertType(Type type) override; + + /// Gets the SPIR-V correspondence for the standard index type. + static Type getIndexType(MLIRContext *context); +}; + +/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V. +template +class SPIRVOpLowering : public OpConversionPattern { +public: + SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), + typeConverter(typeConverter) {} + +protected: + SPIRVTypeConverter &typeConverter; +}; + +#include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc" + +namespace spirv { +/// Returns a value that represents a builtin variable value within the SPIR-V +/// module. +Value getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, + OpBuilder &builder); + +/// Attribute name for specifying argument ABI information. 
+StringRef getInterfaceVarABIAttrName(); + +/// Get the InterfaceVarABIAttr given its fields. +InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, + unsigned binding, + spirv::StorageClass storageClass, + MLIRContext *context); + +/// Attribute name for specifying entry point information. +StringRef getEntryPointABIAttrName(); + +/// Get the EntryPointABIAttr given its fields. +EntryPointABIAttr getEntryPointABIAttr(ArrayRef localSize, + MLIRContext *context); + +/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its +/// arguments +LogicalResult setABIAttrs(FuncOp funcOp, + spirv::EntryPointABIAttr entryPointInfo, + ArrayRef argABIInfo); + +} // namespace spirv +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_SPIRVLOWERING_H diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td new file mode 100644 index 0000000000000000000000000000000000000000..91a8ff68bbf86229156aaa5cc1417f1db3e668fa --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td @@ -0,0 +1,46 @@ +//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the base file for supporting lowering to SPIR-V dialect. This +// file defines SPIR-V attributes used for specifying the shader +// interface or ABI. This is because SPIR-V module is expected to work in +// an execution environment as specified by a client API. A SPIR-V module +// needs to "link" correctly with the execution environment regarding the +// resources that are used in the SPIR-V module and get populated with +// data via the client API. 
The shader interface (or ABI) is passed into +// SPIR-V lowering path via attributes defined in this file. A +// compilation flow targeting SPIR-V is expected to attach such +// attributes to resources and other suitable places. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_LOWERING +#define SPIRV_LOWERING + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +// For arguments that eventually map to spv.globalVariable for the +// shader interface, this attribute specifies the information regarding +// the global variable : +// 1) Descriptor Set. +// 2) Binding number. +// 3) Storage class. +def SPV_InterfaceVarABIAttr: + StructAttr<"InterfaceVarABIAttr", SPV_Dialect, + [StructFieldAttr<"descriptor_set", I32Attr>, + StructFieldAttr<"binding", I32Attr>, + StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>; + +// For entry functions, this attribute specifies information related to entry +// points in the generated SPIR-V module: +// 1) WorkGroup Size. +def SPV_EntryPointABIAttr: + StructAttr<"EntryPointABIAttr", SPV_Dialect, + [StructFieldAttr<"local_size", I32ElementsAttr>]>; + +#endif // SPIRV_LOWERING diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td new file mode 100644 index 0000000000000000000000000000000000000000..f3a9a61a9e93853587cbac2de16edf9a573728c4 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -0,0 +1,69 @@ +//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to +// "3.32.24. 
Non-Uniform Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_NON_UNIFORM_OPS +#define SPIRV_NON_UNIFORM_OPS + +// ----- + +def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> { + let summary = [{ + Returns a bitfield value combining the Predicate value from all + invocations in the group that execute the same dynamic instance of this + instruction. The bit is set to one if the corresponding invocation is + active and the Predicate for that invocation evaluated to true; + otherwise, it is set to zero. + }]; + + let description = [{ + Result Type must be a vector of four components of integer type scalar, + whose Signedness operand is 0. + + Result is a set of bitfields where the first invocation is represented + in the lowest bit of the first vector component and the last (up to the + size of the group) is the higher bit number of the last bitmask needed + to represent all bits of the group invocations. + + Execution must be Workgroup or Subgroup Scope. + + Predicate must be a Boolean type. 
+ + ### Custom assembly form + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope + ssa-use `:` `vector` `<` 4 `x` `integer-type` `>` + ``` + + For example: + + ``` + %0 = spv.GroupNonUniformBallot "Subgroup" %predicate : vector<4xi32> + ``` + }]; + + let arguments = (ins + SPV_ScopeAttr:$execution_scope, + SPV_Bool:$predicate + ); + + let results = (outs + SPV_IntVec4:$result + ); +} + +// ----- + +#endif // SPIRV_NON_UNIFORM_OPS + diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h new file mode 100644 index 0000000000000000000000000000000000000000..2fa417bfe25cbfcf19aeab28467945ae1090642f --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -0,0 +1,41 @@ +//===- SPIRVOps.h - MLIR SPIR-V operations ----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_ +#define MLIR_DIALECT_SPIRV_SPIRVOPS_H_ + +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/IR/Function.h" + +namespace mlir { +class OpBuilder; + +namespace spirv { + +#define GET_OP_CLASSES +#include "mlir/Dialect/SPIRV/SPIRVOps.h.inc" + +/// Following methods are auto-generated. +/// +/// Get the name used in the Op to refer to an enum value of the given +/// `EnumClass`. +/// template StringRef attributeName(); +/// +/// Get the function that can be used to symbolize an enum value. 
+/// template +/// Optional (*)(StringRef) symbolizeEnum(); +#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc" + +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_DIALECT_SPIRV_SPIRVOPS_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td new file mode 100644 index 0000000000000000000000000000000000000000..1ce28928c41c55d25216e239dc609a8bdde2e2fe --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -0,0 +1,468 @@ +//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the main operation definition specification file for SPIR-V +// operations. +// +//===----------------------------------------------------------------------===// + +// Note that for each op in this file and the included files for specific op +// categories, we use a tool to automatically generate certain sections in its +// definition: basic structure, summary, description. So modifications to these +// sections will not be respected. Modifications to op traits, arguments, +// results, and sections after the results are retained. Besides, ops must be +// separated via the '// -----' marker. 
+ +#ifndef SPIRV_OPS +#define SPIRV_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" +include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td" +include "mlir/Dialect/SPIRV/SPIRVAtomicOps.td" +include "mlir/Dialect/SPIRV/SPIRVBitOps.td" +include "mlir/Dialect/SPIRV/SPIRVCastOps.td" +include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td" +include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" +include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" +include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" +include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" +include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td" +include "mlir/Dialect/SPIRV/SPIRVStructureOps.td" + +// ----- + +def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { + let summary = [{ + Create a pointer into a composite object that can be used with OpLoad + and OpStore. + }]; + + let description = [{ + Result Type must be an OpTypePointer. Its Type operand must be the type + reached by walking the Base’s type hierarchy down to the last provided + index in Indexes, and its Storage Class operand must be the same as the + Storage Class of Base. + + Base must be a pointer, pointing to the base of a composite object. + + Indexes walk the type hierarchy to the desired depth, potentially down + to scalar granularity. The first index in Indexes will select the top- + level member/element/component/element of the base composite. All + composite constituents use zero-based numbering, as described by their + OpType… instruction. The second index will apply similarly to that + result, and so on. Once any non-composite type is reached, there must be + no remaining (unused) indexes. + + Each index in Indexes + + - must be a scalar integer type, + + - is treated as a signed count, and + + - must be an OpConstant when indexing into a structure. 
+ + ### Custom assembly form + ``` + access-chain-op ::= ssa-id `=` `spv.AccessChain` ssa-use + `[` ssa-use (',' ssa-use)* `]` + `:` pointer-type + ``` + + For example: + + ``` + %0 = "spv.constant"() { value = 1: i32} : () -> i32 + %1 = spv.Variable : !spv.ptr>, Function> + %2 = spv.AccessChain %1[%0] : !spv.ptr>, Function> + %3 = spv.Load "Function" %2 ["Volatile"] : !spv.array<4xf32> + ``` + }]; + + let arguments = (ins + SPV_AnyPtr:$base_ptr, + Variadic:$indices + ); + + let results = (outs + SPV_AnyPtr:$component_ptr + ); + + let builders = [OpBuilder<[{Builder *builder, OperationState &state, + Value basePtr, ValueRange indices}]>]; + + let hasCanonicalizer = 1; +} + +// ----- + +def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> { + let summary = [{ + Wait for other invocations of this module to reach the current point of + execution. + }]; + + let description = [{ + All invocations of this module within Execution scope must reach this + point of execution before any invocation will proceed beyond it. + + When Execution is Workgroup or larger, behavior is undefined if this + instruction is used in control flow that is non-uniform within + Execution. When Execution is Subgroup or Invocation, the behavior of + this instruction in non-uniform control flow is defined by the client + API. + + If Semantics is not None, this instruction also serves as an + OpMemoryBarrier instruction, and must also perform and adhere to the + description and semantics of an OpMemoryBarrier instruction with the + same Memory and Semantics operands. This allows atomically specifying + both a control barrier and a memory barrier (that is, without needing + two instructions). If Semantics is None, Memory is ignored. + + Before version 1.3, it is only valid to use this instruction with + TessellationControl, GLCompute, or Kernel execution models. There is no + such restriction starting with version 1.3. 
+ + When used with the TessellationControl execution model, it also + implicitly synchronizes the Output Storage Class: Writes to Output + variables performed by any invocation executed prior to a + OpControlBarrier will be visible to any other invocation after return + from that OpControlBarrier. + + ### Custom assembly form + + ``` + scope ::= `"CrossDevice"` | `"Device"` | `"Workgroup"` | ... + + memory-semantics ::= `"None"` | `"Acquire"` | "Release"` | ... + + control-barrier-op ::= `spv.ControlBarrier` scope, scope, memory-semantics + ``` + + For example: + + ``` + spv.ControlBarrier "Workgroup", "Device", "Acquire|UniformMemory" + + ``` + }]; + + let arguments = (ins + SPV_ScopeAttr:$execution_scope, + SPV_ScopeAttr:$memory_scope, + SPV_MemorySemanticsAttr:$memory_semantics + ); + + let results = (outs); + + let verifier = [{ return verifyMemorySemantics(*this); }]; + + let autogenSerialization = 0; +} + +// ----- + +def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> { + let summary = "Declare an execution mode for an entry point."; + + let description = [{ + Entry Point must be the Entry Point operand of an OpEntryPoint + instruction. + + Mode is the execution mode. See Execution Mode. + + This instruction is only valid when the Mode operand is an execution + mode that takes no Extra Operands, or takes Extra Operands that are not + operands. + + ### Custom assembly form + + ``` + execution-mode ::= "Invocations" | "SpacingEqual" | + + + execution-mode-op ::= `spv.ExecutionMode ` ssa-use execution-mode + (integer-literal (`, ` integer-literal)* )? 
+ ``` + + For example: + + ``` + spv.ExecutionMode @foo "ContractionOff" + spv.ExecutionMode @bar "LocalSizeHint", 3, 4, 5 + ``` + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + SPV_ExecutionModeAttr:$execution_mode, + I32ArrayAttr:$values + ); + + let results = (outs); + + let verifier = [{ return success(); }]; + + let autogenSerialization = 0; + + let builders = [OpBuilder<[{Builder *builder, OperationState &state, + FuncOp function, + spirv::ExecutionMode executionMode, + ArrayRef params}]>]; +} + +// ----- + +def SPV_LoadOp : SPV_Op<"Load", []> { + let summary = "Load through a pointer."; + + let description = [{ + Result Type is the type of the loaded object. It must be a type with + fixed size; i.e., it cannot be, nor include, any OpTypeRuntimeArray + types. + + Pointer is the pointer to load through. Its type must be an + OpTypePointer whose Type operand is the same as Result Type. + + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as specifying the memory operand + None. + + ### Custom assembly form + + ``` + memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` integer-literal + | `"NonTemporal"` + + load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use + (`[` memory-access `]`)? 
` : ` spirv-element-type + ``` + + For example: + + ``` + %0 = spv.Variable : !spv.ptr + %1 = spv.Load "Function" %0 : f32 + %2 = spv.Load "Function" %0 ["Volatile"] : f32 + %3 = spv.Load "Function" %0 ["Aligned", 4] : f32 + ``` + }]; + + let arguments = (ins + SPV_AnyPtr:$ptr, + OptionalAttr:$memory_access, + OptionalAttr:$alignment + ); + + let results = (outs + SPV_Type:$value + ); + + let builders = [OpBuilder<[{Builder *builder, OperationState &state, + Value basePtr, /*optional*/IntegerAttr memory_access, + /*optional*/IntegerAttr alignment}]>]; +} + +// ----- + +def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> { + let summary = "Control the order that memory accesses are observed."; + + let description = [{ + Ensures that memory accesses issued before this instruction will be + observed before memory accesses issued after this instruction. This + control is ensured only for memory accesses issued by this invocation + and observed by another invocation executing within Memory scope. If the + Vulkan memory model is declared, this ordering only applies to memory + accesses that use the NonPrivatePointer memory operand or + NonPrivateTexel image operand. + + Semantics declares what kind of memory is being controlled and what kind + of control to apply. + + To execute both a memory barrier and a control barrier, see + OpControlBarrier. + + ### Custom assembly form + + ``` + scope ::= `"CrossDevice"` | `"Device"` | `"Workgroup"` | ... + + memory-semantics ::= `"None"` | `"Acquire"` | `"Release"` | ... 
+ + memory-barrier-op ::= `spv.MemoryBarrier` scope, memory-semantics + ``` + + For example: + + ``` + spv.MemoryBarrier "Device", "Acquire|UniformMemory" + + ``` + }]; + + let arguments = (ins + SPV_ScopeAttr:$memory_scope, + SPV_MemorySemanticsAttr:$memory_semantics + ); + + let results = (outs); + + let verifier = [{ return verifyMemorySemantics(*this); }]; + + let autogenSerialization = 0; +} + +// ----- + +def SPV_StoreOp : SPV_Op<"Store", []> { + let summary = "Store through a pointer."; + + let description = [{ + Pointer is the pointer to store through. Its type must be an + OpTypePointer whose Type operand is the same as the type of Object. + + Object is the object to store. + + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as specifying the memory operand + None. + + ### Custom assembly form + + ``` + store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use `, ` + (`[` memory-access `]`)? `:` spirv-element-type + ``` + + For example: + + ``` + %0 = spv.Variable : !spv.ptr + %1 = spv.FMul ... : f32 + spv.Store "Function" %0, %1 : f32 + spv.Store "Function" %0, %1 ["Volatile"] : f32 + spv.Store "Function" %0, %1 ["Aligned", 4] : f32 + ``` + }]; + + let arguments = (ins + SPV_AnyPtr:$ptr, + SPV_Type:$value, + OptionalAttr:$memory_access, + OptionalAttr:$alignment + ); + + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "Value ptr, Value value, ArrayRef namedAttrs", [{ + state.addOperands(ptr); + state.addOperands(value); + state.addAttributes(namedAttrs); + }]> + ]; + + let results = (outs); +} + +// ----- + +def SPV_UndefOp : SPV_Op<"undef", []> { + let summary = "Make an intermediate object whose value is undefined."; + + let description = [{ + Result Type is the type of object to make. + + Each consumption of Result yields an arbitrary, possibly different + bit pattern or abstract value resulting in possibly different concrete, + abstract, or opaque values. 
+ + ### Custom assembly form + + ``` + undef-op ::= `spv.undef` `:` spirv-type + ``` + + For example: + + ``` + %0 = spv.undef : f32 + %1 = spv.undef : !spv.struct>> + ``` + }]; + + let arguments = (ins); + + let results = (outs + SPV_Type:$result + ); + + let verifier = [{ return success(); }]; + + let hasOpcode = 0; + let autogenSerialization = 0; +} + +// ----- + +def SPV_VariableOp : SPV_Op<"Variable", []> { + let summary = [{ + Allocate an object in memory, resulting in a pointer to it, which can be + used with OpLoad and OpStore. + }]; + + let description = [{ + Result Type must be an OpTypePointer. Its Type operand is the type of + object in memory. + + Storage Class is the Storage Class of the memory holding the object. It + cannot be Generic. It must be the same as the Storage Class operand of + the Result Type. + + Initializer is optional. If Initializer is present, it will be the + initial value of the variable’s memory content. Initializer must be an + <id> from a constant instruction or a global (module scope) OpVariable + instruction. Initializer must have the same type as the type pointed to + by Result Type. + + ### Custom assembly form + + ``` + variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)? + (`bind(` integer-literal, integer-literal `)`)? + (`built_in(` string-literal `)`)? + attribute-dict? `:` spirv-pointer-type + ``` + + where `init` specifies initializer and `bind` specifies the + descriptor set and binding number. `built_in` specifies SPIR-V + BuiltIn decoration associated with the op. + + For example: + + ``` + %0 = spv.constant ... 
+ + %1 = spv.Variable : !spv.ptr + %2 = spv.Variable init(%0): !spv.ptr + %3 = spv.Variable init(%0) bind(1, 2): !spv.ptr + %3 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Uniform> + ``` + }]; + + let arguments = (ins + SPV_StorageClassAttr:$storage_class, + SPV_Optional:$initializer + ); + + let results = (outs + SPV_AnyPtr:$pointer + ); +} + +// ----- + +#endif // SPIRV_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td new file mode 100644 index 0000000000000000000000000000000000000000..c37796b9f60a38ad049e185f2408ecb4aeabba18 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -0,0 +1,461 @@ +//===-- SPIRVStructureOps.td - MLIR SPIR-V Structure Ops ---*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains ops for defining the SPIR-V structure: module, function, +// and module-level operations. The representational form of these ops deviate +// from the SPIR-V binary format in order to utilize MLIR mechanisms. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_STRUCTURE_OPS +#define SPIRV_STRUCTURE_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> { + let summary = "Get the address of a global variable."; + + let description = [{ + Variables in module scope are defined using symbol names. This op generates + an SSA value that can be used to refer to the symbol within function scope + for use in ops that expect an SSA value. This operation has no corresponding + SPIR-V instruction; it's merely used for modelling purpose in the SPIR-V + dialect. 
Since variables in module scope in SPIR-V dialect are of pointer + type, this op returns a pointer type as well, and the type is the same as + the variable referenced. + + ### Custom assembly form + + ``` + spv-address-of-op ::= ssa-id `=` `spv._address_of` symbol-ref-id + `:` spirv-pointer-type + ``` + + For example: + + ``` + %0 = spv._address_of @global_var : !spv.ptr + ``` + }]; + + let arguments = (ins + FlatSymbolRefAttr:$variable + ); + + let results = (outs + SPV_AnyPtr:$pointer + ); + + let hasOpcode = 0; + + let autogenSerialization = 0; + + let builders = [OpBuilder<[{Builder *builder, OperationState &state, + spirv::GlobalVariableOp var}]>]; +} + +def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { + let summary = "The op that declares a SPIR-V normal constant"; + + let description = [{ + This op declares a SPIR-V normal constant. SPIR-V has multiple constant + instructions covering different constant types: + + * `OpConstantTrue` and `OpConstantFalse` for boolean constants + * `OpConstant` for scalar constants + * `OpConstantComposite` for composite constants + * `OpConstantNull` for null constants + * ... + + Having such a plethora of constant instructions renders IR transformations + more tedious. Therefore, we use a single `spv.constant` op to represent + them all. Note that conversion between those SPIR-V constant instructions + and this op is purely mechanical; so it can be scoped to the binary + (de)serialization process. + + ### Custom assembly form + + ``` + spv-constant-op ::= ssa-id `=` `spv.constant` attribute-value + (`:` spirv-type)? 
+ ``` + + For example: + + ``` + %0 = spv.constant true + %1 = spv.constant dense<[2, 3]> : vector<2xf32> + %2 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>> + ``` + + TODO(antiagainst): support constant structs + }]; + + let arguments = (ins + AnyAttr:$value + ); + + let results = (outs + SPV_Type:$constant + ); + + let hasFolder = 1; + + let extraClassDeclaration = [{ + // Returns true if a constant can be built for the given `type`. + static bool isBuildableWith(Type type); + + // Creates a constant zero/one of the given `type` at the current insertion + // point of `builder` and returns it. + static spirv::ConstantOp getZero(Type type, Location loc, + OpBuilder *builder); + static spirv::ConstantOp getOne(Type type, Location loc, + OpBuilder *builder); + }]; + + let hasOpcode = 0; + + let autogenSerialization = 0; +} + +def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> { + let summary = [{ + Declare an entry point, its execution model, and its interface. + }]; + + let description = [{ + Execution Model is the execution model for the entry point and its + static call tree. See Execution Model. + + Entry Point must be the Result of an OpFunction instruction. + + Name is a name string for the entry point. A module cannot have two + OpEntryPoint instructions with the same Execution Model and the same + Name string. + + Interface is a list of symbol references to `spv.globalVariable` + operations. These declare the set of global variables from a + module that form the interface of this entry point. The set of + Interface symbols must be equal to or a superset of the + `spv.globalVariable`s referenced by the entry point’s static call + tree, within the interface’s storage classes. Before version 1.4, + the interface’s storage classes are limited to the Input and + Output storage classes. 
Starting with version 1.4, the interface’s + storage classes are all storage classes used in declaring all + global variables referenced by the entry point’s call tree. + + ### Custom assembly form + + ``` + execution-model ::= "Vertex" | "TesellationControl" | + + + entry-point-op ::= ssa-id `=` `spv.EntryPoint` execution-model + symbol-reference (`, ` symbol-reference)* + ``` + + For example: + + ``` + spv.EntryPoint "GLCompute" @foo + spv.EntryPoint "Kernel" @foo, @var1, @var2 + + ``` + }]; + + let arguments = (ins + SPV_ExecutionModelAttr:$execution_model, + FlatSymbolRefAttr:$fn, + SymbolRefArrayAttr:$interface + ); + + let results = (outs); + + let autogenSerialization = 0; + + let builders = [OpBuilder<[{Builder *builder, OperationState &state, + spirv::ExecutionModel executionModel, + FuncOp function, + ArrayRef interfaceVars}]>]; +} + + +def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> { + let summary = [{ + Allocate an object in memory at module scope. The object is + referenced using a symbol name. + }]; + + let description = [{ + The variable type must be an OpTypePointer. Its type operand is the type of + object in memory. + + Storage Class is the Storage Class of the memory holding the object. It + cannot be Generic. It must be the same as the Storage Class operand of + the variable types. Only those storage classes that are valid at module + scope (like Input, Output, StorageBuffer, etc.) are valid. + + Initializer is optional. If Initializer is present, it will be + the initial value of the variable’s memory content. Initializer + must be an symbol defined from a constant instruction or other + `spv.globalVariable` operation in module scope. Initializer must + have the same type as the type of the defined symbol. + + ### Custom assembly form + + ``` + variable-op ::= `spv.globalVariable` spirv-type symbol-ref-id + (`initializer(` symbol-ref-id `)`)? + (`bind(` integer-literal, integer-literal `)`)? 
+ (`built_in(` string-literal `)`)? + attribute-dict? + ``` + + where `initializer` specifies initializer and `bind` specifies the + descriptor set and binding number. `built_in` specifies SPIR-V + BuiltIn decoration associated with the op. + + For example: + + ``` + spv.globalVariable @var0 : !spv.ptr @var0 + spv.globalVariable @var1 initializer(@var0) : !spv.ptr + spv.globalVariable @var2 bind(1, 2) : !spv.ptr + spv.globalVariable @var3 built_in("GlobalInvocationId") : !spv.ptr, Input> + ``` + }]; + + let arguments = (ins + TypeAttr:$type, + StrAttr:$sym_name, + OptionalAttr:$initializer + ); + + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, " + "TypeAttr type, ArrayRef namedAttrs", [{ + state.addAttribute("type", type); + state.addAttributes(namedAttrs); + }]>, + OpBuilder<[{Builder *builder, OperationState &state, + Type type, StringRef name, unsigned descriptorSet, + unsigned binding}]>, + OpBuilder<[{Builder *builder, OperationState &state, + Type type, StringRef name, spirv::BuiltIn builtin}]> + ]; + + let results = (outs); + + let hasOpcode = 0; + + let autogenSerialization = 0; + + let extraClassDeclaration = [{ + ::mlir::spirv::StorageClass storageClass() { + return this->type().cast<::mlir::spirv::PointerType>().getStorageClass(); + } + }]; +} + +def SPV_ModuleOp : SPV_Op<"module", + [IsolatedFromAbove, + SingleBlockImplicitTerminator<"ModuleEndOp">, + NativeOpTrait<"SymbolTable">]> { + let summary = "The top-level op that defines a SPIR-V module"; + + let description = [{ + This op defines a SPIR-V module using a MLIR region. The region contains + one block. Module-level operations, including functions definitions, + are all placed in this block. + + Using an op with a region to define a SPIR-V module enables "embedding" + SPIR-V modules in other dialects in a clean manner: this op guarantees + the validity and serializability of a SPIR-V module and thus serves as + a clear-cut boundary. 
+ + This op takes no operands and generates no results. This op should not + implicitly capture values from the enclosing environment. + + This op has only one region, which only contains one block. The block + must be terminated via the `spv._module_end` op. + + ### Custom assembly form + + ``` + addressing-model ::= `"Logical"` | `"Physical32"` | `"Physical64"` + memory-model ::= `"Simple"` | `"GLSL450"` | `"OpenCL"` | `"VulkanKHR"` + spv-module-op ::= `spv.module` addressing-model memory-model + region + (`attributes` attribute-dict)? + ``` + + For example: + + ``` + spv.module "Logical" "VulkanKHR" { } + + spv.module "Logical" "VulkanKHR" { + func @do_nothing() -> () { + spv.Return + } + } attributes { + capability = ["Shader"], + extension = ["SPV_KHR_16bit_storage"] + } + ``` + }]; + + let arguments = (ins + SPV_AddressingModelAttr:$addressing_model, + SPV_MemoryModelAttr:$memory_model, + OptionalAttr:$capabilities, + OptionalAttr:$extensions, + OptionalAttr:$extended_instruction_sets + ); + + let results = (outs); + + let regions = (region SizedRegion<1>:$body); + + let builders = + [OpBuilder<"Builder *, OperationState &state">, + OpBuilder<[{Builder *, OperationState &state, + IntegerAttr addressing_model, + IntegerAttr memory_model}]>, + OpBuilder<[{Builder *, OperationState &state, + spirv::AddressingModel addressing_model, + spirv::MemoryModel memory_model, + /*optional*/ ArrayRef capabilities = {}, + /*optional*/ ArrayRef extensions = {}, + /*optional*/ ArrayAttr extended_instruction_sets = nullptr}]>]; + + // We need to ensure the block inside the region is properly terminated; + // the auto-generated builders do not guarantee that. 
+ let skipDefaultBuilders = 1; + + let hasOpcode = 0; + + let autogenSerialization = 0; + + let extraClassDeclaration = [{ + Block& getBlock() { + return this->getOperation()->getRegion(0).front(); + } + }]; +} + +def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> { + let summary = "The pseudo op that ends a SPIR-V module"; + + let description = [{ + This op terminates the only block inside a `spv.module`'s only region. + This op does not have a corresponding SPIR-V instruction and thus will + not be serialized into the binary format; it is used solely to satisfy + the structual requirement that an block must be ended with a terminator. + }]; + + let arguments = (ins); + + let results = (outs); + + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; + + let verifier = [{ return success(); }]; + + let hasOpcode = 0; + + let autogenSerialization = 0; +} + +def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> { + let summary = "Reference a specialization constant."; + + let description = [{ + Specialization constant in module scope are defined using symbol names. + This op generates an SSA value that can be used to refer to the symbol + within function scope for use in ops that expect an SSA value. + This operation has no corresponding SPIR-V instruction; it's merely used + for modelling purpose in the SPIR-V dialect. This op's return type is + the same as the specialization constant. 
+ + ### Custom assembly form + + ``` + spv-reference-of-op ::= ssa-id `=` `spv._reference_of` symbol-ref-id + `:` spirv-scalar-type + ``` + + For example: + + ``` + %0 = spv._reference_of @spec_const : f32 + ``` + }]; + + let arguments = (ins + FlatSymbolRefAttr:$spec_const + ); + + let results = (outs + SPV_Type:$reference + ); + + let hasOpcode = 0; + + let autogenSerialization = 0; +} + +def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> { + let summary = "The op that declares a SPIR-V specialization constant"; + + let description = [{ + This op declares a SPIR-V scalar specialization constant. SPIR-V has + multiple constant instructions covering different scalar types: + + * `OpSpecConstantTrue` and `OpSpecConstantFalse` for boolean constants + * `OpSpecConstant` for scalar constants + + Similar as `spv.constant`, this op represents all of the above cases. + `OpSpecConstantComposite` and `OpSpecConstantOp` are modelled with + separate ops. + + ### Custom assembly form + + ``` + spv-spec-constant-op ::= `spv.specConstant` symbol-ref-id + `spec_id(` integer `)` + `=` attribute-value (`:` spirv-type)? + ``` + + where `spec_id` specifies the SPIR-V SpecId decoration associated with + the op. 
+ + For example: + + ``` + spv.specConstant @spec_const1 = true + spv.specConstant @spec_const2 spec_id(5) = 42 : i32 + ``` + + TODO(antiagainst): support composite spec constants with another op + }]; + + let arguments = (ins + StrAttr:$sym_name, + AnyAttr:$default_value + ); + + let results = (outs); + + let hasOpcode = 0; + + let autogenSerialization = 0; +} + +#endif // SPIRV_STRUCTURE_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h new file mode 100644 index 0000000000000000000000000000000000000000..001d3130778402c39dae08ef9d9d573482f81762 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -0,0 +1,197 @@ +//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the types in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ +#define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ + +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +// Pull in all enum type definitions and utility function declarations +#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc" + +#include + +namespace mlir { +namespace spirv { + +namespace detail { +struct ArrayTypeStorage; +struct ImageTypeStorage; +struct PointerTypeStorage; +struct RuntimeArrayTypeStorage; +struct StructTypeStorage; +} // namespace detail + +namespace TypeKind { +enum Kind { + Array = Type::FIRST_SPIRV_TYPE, + Image, + Pointer, + RuntimeArray, + Struct, + LAST_SPIRV_TYPE = Struct, +}; +} + +// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType. 
+class CompositeType : public Type { +public: + using Type::Type; + + static bool classof(Type type); + + unsigned getNumElements() const; + + Type getElementType(unsigned) const; +}; + +// SPIR-V array type +class ArrayType : public Type::TypeBase { +public: + using Base::Base; + // Zero layout specifies that is no layout + using LayoutInfo = uint64_t; + + static bool kindof(unsigned kind) { return kind == TypeKind::Array; } + + static ArrayType get(Type elementType, unsigned elementCount); + + static ArrayType get(Type elementType, unsigned elementCount, + LayoutInfo layoutInfo); + + unsigned getNumElements() const; + + Type getElementType() const; + + bool hasLayout() const; + + uint64_t getArrayStride() const; +}; + +// SPIR-V image type +class ImageType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::Image; } + + static ImageType + get(Type elementType, Dim dim, + ImageDepthInfo depth = ImageDepthInfo::DepthUnknown, + ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed, + ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled, + ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown, + ImageFormat format = ImageFormat::Unknown) { + return ImageType::get( + std::tuple( + elementType, dim, depth, arrayed, samplingInfo, samplerUse, + format)); + } + + static ImageType + get(std::tuple); + + Type getElementType() const; + Dim getDim() const; + ImageDepthInfo getDepthInfo() const; + ImageArrayedInfo getArrayedInfo() const; + ImageSamplingInfo getSamplingInfo() const; + ImageSamplerUseInfo getSamplerUseInfo() const; + ImageFormat getImageFormat() const; + // TODO(ravishankarm): Add support for Access qualifier +}; + +// SPIR-V pointer type +class PointerType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::Pointer; } + + static PointerType get(Type pointeeType, StorageClass storageClass); + + 
Type getPointeeType() const; + + StorageClass getStorageClass() const; +}; + +// SPIR-V run-time array type +class RuntimeArrayType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; } + + static RuntimeArrayType get(Type elementType); + + Type getElementType() const; +}; + +// SPIR-V struct type +class StructType : public Type::TypeBase { +public: + using Base::Base; + + // Layout information used for members in a struct in SPIR-V + // + // TODO(ravishankarm) : For now this only supports the offset type, so uses + // uint64_t value to represent the offset, with + // std::numeric_limit::max indicating no offset. Change this to + // something that can hold all the information needed for different member + // types + using LayoutInfo = uint64_t; + + using MemberDecorationInfo = std::pair; + + static bool kindof(unsigned kind) { return kind == TypeKind::Struct; } + + /// Construct a StructType with at least one member. + static StructType get(ArrayRef memberTypes, + ArrayRef layoutInfo = {}, + ArrayRef memberDecorations = {}); + + /// Construct a struct with no members. + static StructType getEmpty(MLIRContext *context); + + unsigned getNumElements() const; + + Type getElementType(unsigned) const; + + bool hasLayout() const; + + uint64_t getOffset(unsigned) const; + + // Returns in `allMemberDecorations` the spirv::Decorations (apart from + // Offset) associated with all members of the StructType. + void getMemberDecorations(SmallVectorImpl + &allMemberDecorations) const; + + // Returns in `memberDecorations` all the spirv::Decorations (apart from + // Offset) associated with the `i`-th member of the StructType. 
+ void getMemberDecorations( + unsigned i, SmallVectorImpl &memberDecorations) const; +}; + +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h new file mode 100644 index 0000000000000000000000000000000000000000..e8240b0072e822573e17b8b4184d9f4a6cfe120d --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h @@ -0,0 +1,40 @@ +//===- Serialization.h - MLIR SPIR-V (De)serialization ----------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the entry points for serialize and deserialize SPIR-V +// binary modules. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_SERIALIZATION_H_ +#define MLIR_DIALECT_SPIRV_SERIALIZATION_H_ + +#include "mlir/Support/LLVM.h" + +namespace mlir { +struct LogicalResult; +class MLIRContext; + +namespace spirv { +class ModuleOp; + +/// Serializes the given SPIR-V `module` and writes to `binary`. On failure, +/// reports errors to the error handler registered with the MLIR context for +/// `module`. +LogicalResult serialize(ModuleOp module, SmallVectorImpl &binary); + +/// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp +/// in the given `context`. Returns the ModuleOp on success; otherwise, reports +/// errors to the error handler registered with `context` and returns +/// llvm::None. 
+Optional deserialize(ArrayRef binary, MLIRContext *context); + +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_DIALECT_SPIRV_SERIALIZATION_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b6534797a065cbcc148416b41ee1aaecdeb89b36 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRStandardOpsIncGen) diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0ba16c56f8eb6ad012d2af2f8725f0ff1d9db02f --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -0,0 +1,342 @@ +//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines convenience types for working with standard operations +// in the MLIR operation set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_STANDARDOPS_OPS_H +#define MLIR_DIALECT_STANDARDOPS_OPS_H + +#include "mlir/Analysis/CallInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +// Pull in all enum type definitions and utility function declarations. 
+#include "mlir/Dialect/StandardOps/OpsEnums.h.inc" + +namespace mlir { +class AffineMap; +class Builder; +class FuncOp; +class OpBuilder; + +class StandardOpsDialect : public Dialect { +public: + StandardOpsDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "std"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; +}; + +/// The predicate indicates the type of the comparison to perform: +/// (un)orderedness, (in)equality and less/greater than (or equal to) as +/// well as predicates that are always true or false. +enum class CmpFPredicate { + FirstValidValue, + // Always false + AlwaysFalse = FirstValidValue, + // Ordered comparisons + OEQ, + OGT, + OGE, + OLT, + OLE, + ONE, + // Both ordered + ORD, + // Unordered comparisons + UEQ, + UGT, + UGE, + ULT, + ULE, + UNE, + // Any unordered + UNO, + // Always true + AlwaysTrue, + // Number of predicates. + NumPredicates +}; + +#define GET_OP_CLASSES +#include "mlir/Dialect/StandardOps/Ops.h.inc" + +/// This is a refinement of the "constant" op for the case where it is +/// returning a float value of FloatType. +/// +/// %1 = "std.constant"(){value: 42.0} : bf16 +/// +class ConstantFloatOp : public ConstantOp { +public: + using ConstantOp::ConstantOp; + + /// Builds a constant float op producing a float of the specified type. + static void build(Builder *builder, OperationState &result, + const APFloat &value, FloatType type); + + APFloat getValue() { return getAttrOfType("value").getValue(); } + + static bool classof(Operation *op); +}; + +/// This is a refinement of the "constant" op for the case where it is +/// returning an integer value of IntegerType. 
+/// +/// %1 = "std.constant"(){value: 42} : i32 +/// +class ConstantIntOp : public ConstantOp { +public: + using ConstantOp::ConstantOp; + /// Build a constant int op producing an integer of the specified width. + static void build(Builder *builder, OperationState &result, int64_t value, + unsigned width); + + /// Build a constant int op producing an integer with the specified type, + /// which must be an integer type. + static void build(Builder *builder, OperationState &result, int64_t value, + Type type); + + int64_t getValue() { return getAttrOfType("value").getInt(); } + + static bool classof(Operation *op); +}; + +/// This is a refinement of the "constant" op for the case where it is +/// returning an integer value of Index type. +/// +/// %1 = "std.constant"(){value: 99} : () -> index +/// +class ConstantIndexOp : public ConstantOp { +public: + using ConstantOp::ConstantOp; + + /// Build a constant int op producing an index. + static void build(Builder *builder, OperationState &result, int64_t value); + + int64_t getValue() { return getAttrOfType("value").getInt(); } + + static bool classof(Operation *op); +}; + +// DmaStartOp starts a non-blocking DMA operation that transfers data from a +// source memref to a destination memref. The source and destination memref need +// not be of the same dimensionality, but need to have the same elemental type. +// The operands include the source and destination memref's each followed by its +// indices, size of the data transfer in terms of the number of elements (of the +// elemental type of the memref), a tag memref with its indices, and optionally +// at the end, a stride and a number_of_elements_per_stride arguments. The tag +// location is used by a DmaWaitOp to check for completion. The indices of the +// source memref, destination memref, and the tag memref have the same +// restrictions as any load/store. 
The optional stride arguments should be of +// 'index' type, and specify a stride for the slower memory space (memory space +// with a lower memory space id), transferring chunks of +// number_of_elements_per_stride every stride until %num_elements are +// transferred. Either both or no stride arguments should be specified. +// +// For example, a DmaStartOp operation that transfers 256 elements of a memref +// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space +// 1 at indices [%k, %l], would be specified as follows: +// +// %num_elements = constant 256 +// %idx = constant 0 : index +// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4> +// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : +// memref<40 x 128 x f32>, (d0) -> (d0), 0>, +// memref<2 x 1024 x f32>, (d0) -> (d0), 1>, +// memref<1 x i32>, (d0) -> (d0), 2> +// +// If %stride and %num_elt_per_stride are specified, the DMA is expected to +// transfer %num_elt_per_stride elements every %stride elements apart from +// memory space 0 until %num_elements are transferred. +// +// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, +// %num_elt_per_stride : +// +// TODO(mlir-team): add additional operands to allow source and destination +// striding, and multiple stride levels. +// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. +class DmaStartOp + : public Op { +public: + using Op::Op; + + static void build(Builder *builder, OperationState &result, Value srcMemRef, + ValueRange srcIndices, Value destMemRef, + ValueRange destIndices, Value numElements, Value tagMemRef, + ValueRange tagIndices, Value stride = nullptr, + Value elementsPerStride = nullptr); + + // Returns the source MemRefType for this DMA operation. + Value getSrcMemRef() { return getOperand(0); } + // Returns the rank (number of indices) of the source MemRefType. 
+ unsigned getSrcMemRefRank() { + return getSrcMemRef()->getType().cast().getRank(); + } + // Returns the source memref indices for this DMA operation. + operand_range getSrcIndices() { + return {getOperation()->operand_begin() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; + } + + // Returns the destination MemRefType for this DMA operations. + Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } + // Returns the rank (number of indices) of the destination MemRefType. + unsigned getDstMemRefRank() { + return getDstMemRef()->getType().cast().getRank(); + } + unsigned getSrcMemorySpace() { + return getSrcMemRef()->getType().cast().getMemorySpace(); + } + unsigned getDstMemorySpace() { + return getDstMemRef()->getType().cast().getMemorySpace(); + } + + // Returns the destination memref indices for this DMA operation. + operand_range getDstIndices() { + return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + + getDstMemRefRank()}; + } + + // Returns the number of elements being transferred by this DMA operation. + Value getNumElements() { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); + } + + // Returns the Tag MemRef for this DMA operation. + Value getTagMemRef() { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); + } + // Returns the rank (number of indices) of the tag MemRefType. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + // Returns the tag memref index for this DMA operation. + operand_range getTagIndices() { + unsigned tagIndexStartPos = + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; + return {getOperation()->operand_begin() + tagIndexStartPos, + getOperation()->operand_begin() + tagIndexStartPos + + getTagMemRefRank()}; + } + + /// Returns true if this is a DMA from a faster memory space to a slower one. 
+ bool isDestMemorySpaceFaster() { + return (getSrcMemorySpace() < getDstMemorySpace()); + } + + /// Returns true if this is a DMA from a slower memory space to a faster one. + bool isSrcMemorySpaceFaster() { + // Assumes that a lower number is for a slower memory space. + return (getDstMemorySpace() < getSrcMemorySpace()); + } + + /// Given a DMA start operation, returns the operand position of either the + /// source or destination memref depending on the one that is at the higher + /// level of the memory hierarchy. Asserts failure if neither is true. + unsigned getFasterMemPos() { + assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); + return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; + } + + static StringRef getOperationName() { return "std.dma_start"; } + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); + + bool isStrided() { + return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + + 1 + 1 + getTagMemRefRank(); + } + + Value getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1 - 1); + } + + Value getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } +}; + +// DmaWaitOp blocks until the completion of a DMA operation associated with the +// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index +// with the same restrictions as any load/store index. %num_elements is the +// number of elements associated with the DMA operation. For example: +// +// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : +// memref<2048 x f32>, (d0) -> (d0), 0>, +// memref<256 x f32>, (d0) -> (d0), 1> +// memref<1 x i32>, (d0) -> (d0), 2> +// ... +// ... 
+// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> +// +class DmaWaitOp + : public Op { +public: + using Op::Op; + + static void build(Builder *builder, OperationState &result, Value tagMemRef, + ValueRange tagIndices, Value numElements); + + static StringRef getOperationName() { return "std.dma_wait"; } + + // Returns the Tag MemRef associated with the DMA operation being waited on. + Value getTagMemRef() { return getOperand(0); } + + // Returns the tag memref index for this DMA operation. + operand_range getTagIndices() { + return {getOperation()->operand_begin() + 1, + getOperation()->operand_begin() + 1 + getTagMemRefRank()}; + } + + // Returns the rank (number of indices) of the tag memref. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + // Returns the number of elements transferred in the associated DMA operation. + Value getNumElements() { return getOperand(1 + getTagMemRefRank()); } + + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); +}; + +/// Prints dimension and symbol list. +void printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, unsigned numDims, + OpAsmPrinter &p); + +/// Parses dimension and symbol list and returns true if parsing failed. 
+ParseResult parseDimAndSymbolList(OpAsmParser &parser, + SmallVectorImpl &operands, + unsigned &numDims); + +raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range); + +} // end namespace mlir + +#endif // MLIR_DIALECT_STANDARDOPS_OPS_H diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td new file mode 100644 index 0000000000000000000000000000000000000000..1c8bb251c0298740cbfe58a40a83a3d602202cea --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -0,0 +1,1626 @@ +//===- Ops.td - Standard operation definitions -------------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines some MLIR standard operations. +// +//===----------------------------------------------------------------------===// + +#ifndef STANDARD_OPS +#define STANDARD_OPS + +include "mlir/Analysis/CallInterfaces.td" +include "mlir/IR/OpAsmInterface.td" + +def Std_Dialect : Dialect { + let name = "std"; + let cppNamespace = ""; +} + +// Base class for Standard dialect ops. +class Std_Op traits = []> : + Op { + // For every standard op, there needs to be a: + // * void print(OpAsmPrinter &p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, + // OperationState &result) + // functions. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +// Base class for standard cast operations. Requires single operand and result, +// but does not constrain them to specific types. 
+class CastOp traits = []> : + Std_Op { + + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value source, Type destType", [{ + impl::buildCastOp(builder, result, source, destType); + }]>]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; + let verifier = [{ return ::verifyCastOp(*this); }]; + + let hasFolder = 1; +} + +// Base class for unary ops. Requires single operand and result. Individual +// classes will have `operand` accessor. +class UnaryOp traits = []> : + Op { + let results = (outs AnyType); + let printer = [{ + return printStandardUnaryOp(this->getOperation(), p); + }]; +} + +class UnaryOpSameOperandAndResultType traits = []> : + UnaryOp { + let parser = [{ + return impl::parseOneResultSameOperandTypeOp(parser, result); + }]; +} + +class FloatUnaryOp traits = []> : + UnaryOpSameOperandAndResultType, + Arguments<(ins FloatLike:$operand)>; + +// Base class for standard arithmetic operations. Requires operands and +// results to be of the same type, but does not constrain them to specific +// types. Individual classes will have `lhs` and `rhs` accessor to operands. +class ArithmeticOp traits = []> : + Op { + + let results = (outs AnyType); + + let parser = [{ + return impl::parseOneResultSameOperandTypeOp(parser, result); + }]; + + let printer = [{ + return printStandardBinaryOp(this->getOperation(), p); + }]; +} + +// Base class for standard arithmetic operations on integers, vectors and +// tensors thereof. This operation takes two operands and returns one result, +// each of these is required to be of the same type. This type may be an +// integer scalar type, a vector whose element type is an integer type, or an +// integer tensor. 
The custom assembly form of the operation is as follows +// +// i %0, %1 : i32 +class IntArithmeticOp traits = []> : + ArithmeticOp, + Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>; + +// Base class for standard arithmetic binary operations on floats, vectors and +// tensors thereof. This operation has two operands and returns one result, +// each of these is required to be of the same type. This type may be a +// floating point scalar type, a vector whose element type is a floating point +// type, or a floating point tensor. The custom assembly form of the operation +// is as follows +// +// f %0, %1 : f32 +class FloatArithmeticOp traits = []> : + ArithmeticOp, + Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; + +def AbsFOp : FloatUnaryOp<"absf"> { + let summary = "floating point absolute-value operation"; + let description = [{ + The `absf` operation computes the absolute value. It takes one operand and + returns one result of the same type. This type may be a float scalar type, + a vector whose element type is float, or a tensor of floats. It has no + standard attributes. + }]; +} + +def AddFOp : FloatArithmeticOp<"addf"> { + let summary = "floating point addition operation"; + let hasFolder = 1; +} + +def AddIOp : IntArithmeticOp<"addi", [Commutative]> { + let summary = "integer addition operation"; + let hasFolder = 1; +} + +def AllocOp : Std_Op<"alloc"> { + let summary = "memory allocation operation"; + let description = [{ + The "alloc" operation allocates a region of memory, as specified by its + memref type. For example: + + %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + + The optional list of dimension operands are bound to the dynamic dimensions + specified in its memref type. In the example below, the ssa value '%d' is + bound to the second dimension of the memref (which is dynamic). 
+ + %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1> + + The optional list of symbol operands are bound to the symbols of the + memrefs affine map. In the example below, the ssa value '%s' is bound to + the symbol 's0' in the affine map specified in the allocs memref type. + + %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> + + This operation returns a single ssa value of memref type, which can be used + by subsequent load and store operations. + + The optional `alignment` attribute may be specified to ensure that the + region of memory that will be indexed is aligned at the specified byte + boundary. TODO(b/144281289) optional alignment attribute to MemRefType. + + %0 = alloc()[%s] {alignment = 8} : + memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> + }]; + + let arguments = (ins Variadic:$value, + Confined, [IntMinValue<0>]>:$alignment); + let results = (outs AnyMemRef); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, MemRefType memrefType", [{ + result.types.push_back(memrefType); + }]>, + OpBuilder< + "Builder *builder, OperationState &result, MemRefType memrefType, " # + "ArrayRef operands, IntegerAttr alignment = IntegerAttr()", [{ + result.addOperands(operands); + result.types.push_back(memrefType); + if (alignment) + result.addAttribute(getAlignmentAttrName(), alignment); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getAlignmentAttrName() { return "alignment"; } + + MemRefType getType() { return getResult()->getType().cast(); } + + /// Returns the number of symbolic operands (the ones in square brackets), + /// which bind to the symbols of the memref's layout map. + unsigned getNumSymbolicOperands() { + return getNumOperands() - getType().getNumDynamicDims(); + } + + /// Returns the symbolic operands (the ones in square brackets), which bind + /// to the symbols of the memref's layout map. 
+ operand_range getSymbolicOperands() { + return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; + } + + /// Returns the dynamic sizes for this alloc operation if specified. + operand_range getDynamicSizes() { return getOperands(); } + }]; + + let hasCanonicalizer = 1; +} + +def AndOp : IntArithmeticOp<"and", [Commutative]> { + let summary = "integer binary and"; + let hasFolder = 1; +} + +def BranchOp : Std_Op<"br", [Terminator]> { + let summary = "branch operation"; + let description = [{ + The "br" operation represents a branch operation in a function. + The operation takes variable number of operands and produces no results. + The operand number and types for each successor must match the arguments of + the block successor. For example: + + ^bb2: + %2 = call @someFn() + br ^bb3(%2 : tensor<*xf32>) + ^bb3(%3: tensor<*xf32>): + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Block *dest," + "ValueRange operands = {}", [{ + result.addSuccessor(dest, operands); + }]>]; + + // BranchOp is fully verified by traits. + let verifier = ?; + + let extraClassDeclaration = [{ + Block *getDest(); + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + }]; + + let hasCanonicalizer = 1; +} + +def CallOp : Std_Op<"call", [CallOpInterface]> { + let summary = "call operation"; + let description = [{ + The "call" operation represents a direct call to a function that is within + the same symbol scope as the call. The operands and result types of the + call must match the specified function type. The callee is encoded as a + function attribute named "callee". 
+ + %2 = call @my_add(%0, %1) : (f32, f32) -> f32 + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, FuncOp callee," + "ValueRange operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState &result, SymbolRefAttr callee," + "ArrayRef results, ValueRange operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", callee); + result.addTypes(results); + }]>, OpBuilder< + "Builder *builder, OperationState &result, StringRef callee," + "ArrayRef results, ValueRange operands = {}", [{ + build(builder, result, builder->getSymbolRefAttr(callee), results, + operands); + }]>]; + + let extraClassDeclaration = [{ + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return getAttrOfType("callee"); + } + }]; +} + +def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { + let summary = "indirect call operation"; + let description = [{ + The "call_indirect" operation represents an indirect call to a value of + function type. Functions are first class types in MLIR, and may be passed + as arguments and merged together with block arguments. The operands + and result types of the call must match the specified function type. 
+ + %3 = call_indirect %2(%0, %1) : (f32, f32) -> f32 + }]; + + let arguments = (ins FunctionType:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value callee," + "ValueRange operands = {}", [{ + result.operands.push_back(callee); + result.addOperands(operands); + result.addTypes(callee->getType().cast().getResults()); + }]>]; + + let extraClassDeclaration = [{ + Value getCallee() { return getOperand(0); } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return ++operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { return getCallee(); } + }]; + + let hasCanonicalizer = 1; +} + +def CeilFOp : FloatUnaryOp<"ceilf"> { + let summary = "ceiling of the specified value"; + let description = [{ + The `ceilf` operation computes the ceiling of a given value. It takes one + operand and returns one result of the same type. This type may be a float + scalar type, a vector whose element type is float, or a tensor of floats. + It has no standard attributes. + }]; +} + +def CmpFOp : Std_Op<"cmpf", + [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { + let summary = "floating-point comparison operation"; + let description = [{ + The "cmpf" operation compares its two operands according to the float + comparison rules and the predicate specified by the respective attribute. + The predicate defines the type of comparison: (un)orderedness, (in)equality + and signed less/greater than (or equal to) as well as predicates that are + always true or false. The operands must have the same type, and this type + must be a float type, or a vector or tensor thereof. 
The result is an i1, + or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, + the operands are always treated as signed. The u prefix indicates + *unordered* comparison, not unsigned comparison, so "une" means unordered or + not equal. For the sake of readability by humans, custom assembly form for + the operation uses a string-typed attribute for the predicate. The value of + this attribute corresponds to lower-cased name of the predicate constant, + e.g., "one" means "ordered not equal". The string representation of the + attribute is merely a syntactic sugar and is converted to an integer + attribute by the parser. + + %r1 = cmpf "oeq" %0, %1 : f32 + %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> + %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 + }]; + + let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); + let results = (outs BoolLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, CmpFPredicate predicate," + "Value lhs, Value rhs", [{ + ::buildCmpFOp(builder, result, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpFPredicate getPredicateByName(StringRef name); + + CmpFPredicate getPredicate() { + return (CmpFPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + }]; + + let hasFolder = 1; +} + +def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; +def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; +def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; +def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>; +def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>; +def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>; +def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>; +def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>; +def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>; +def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>; + +def CmpIPredicateAttr : I64EnumAttr< + "CmpIPredicate", "", + [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT, + CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, 
CMPI_P_UGT, CMPI_P_UGE]> { + let cppNamespace = "::mlir"; +} + +def CmpIOp : Std_Op<"cmpi", + [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { + let summary = "integer comparison operation"; + let description = [{ + The "cmpi" operation compares its two operands according to the integer + comparison rules and the predicate specified by the respective attribute. + The predicate defines the type of comparison: (in)equality, (un)signed + less/greater than (or equal to). The operands must have the same type, and + this type must be an integer type, a vector or a tensor thereof. The result + is an i1, or a vector/tensor thereof having the same shape as the inputs. + Since integers are signless, the predicate also explicitly indicates + whether to interpret the operands as signed or unsigned integers for + less/greater than comparisons. For the sake of readability by humans, + custom assembly form for the operation uses a string-typed attribute for + the predicate. The value of this attribute corresponds to lower-cased name + of the predicate constant, e.g., "slt" means "signed less than". The string + representation of the attribute is merely a syntactic sugar and is converted + to an integer attribute by the parser. 
+ + %r1 = cmpi "eq" %0, %1 : i32 + %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64> + %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 + }]; + + let arguments = (ins + CmpIPredicateAttr:$predicate, + IntegerLike:$lhs, + IntegerLike:$rhs + ); + let results = (outs BoolLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, CmpIPredicate predicate," + "Value lhs, Value rhs", [{ + ::buildCmpIOp(builder, result, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpIPredicate getPredicateByName(StringRef name); + + CmpIPredicate getPredicate() { + return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + }]; + + let verifier = [{ return success(); }]; + + let hasFolder = 1; +} + +def CondBranchOp : Std_Op<"cond_br", [Terminator]> { + let summary = "conditional branch operation"; + let description = [{ + The "cond_br" operation represents a conditional branch operation in a + function. The operation takes variable number of operands and produces + no results. The operand number and types for each successor must match the + arguments of the block successor. For example: + + ^bb0: + %0 = extract_element %arg0[] : tensor + cond_br %0, ^bb1, ^bb2 + ^bb1: + ... + ^bb2: + ... + }]; + + let arguments = (ins I1:$condition, Variadic:$branchOperands); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value condition," + "Block *trueDest, ValueRange trueOperands," + "Block *falseDest, ValueRange falseOperands", [{ + result.addOperands(condition); + result.addSuccessor(trueDest, trueOperands); + result.addSuccessor(falseDest, falseOperands); + }]>]; + + // CondBranchOp is fully verified by traits. + let verifier = ?; + + let extraClassDeclaration = [{ + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + // The condition operand is the first operand in the list. 
+ Value getCondition() { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest() { + return getSuccessor(trueIndex); + } + + /// Return the destination if the condition is false. + Block *getFalseDest() { + return getSuccessor(falseIndex); + } + + // Accessors for operands to the 'true' destination. + Value getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + + void setTrueOperand(unsigned idx, Value value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + operand_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() { + return getNumSuccessorOperands(trueIndex); + } + + /// Erase the operand at 'index' from the true operand list. + void eraseTrueOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(trueIndex, index); + } + + // Accessors for operands to the 'false' destination. + Value getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + void setFalseOperand(unsigned idx, Value value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + operand_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() { + return getNumSuccessorOperands(falseIndex); + } + + /// Erase the operand at 'index' from the false operand list. 
+ void eraseFalseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(falseIndex, index); + } + + private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + }]; + + let hasCanonicalizer = 1; +} + +def ConstantOp : Std_Op<"constant", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "constant"; + + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Attribute value", + [{ build(builder, result, value.getType(), value); }]>]; + + let extraClassDeclaration = [{ + Attribute getValue() { return getAttr("value"); } + + /// Returns true if a constant operation can be built with the given value + /// and result type. + static bool isBuildableWith(Attribute value, Type type); + }]; + + let hasFolder = 1; +} + +def CopySignOp : FloatArithmeticOp<"copysign"> { + let summary = "A copysign operation"; + let description = [{ + The `copysign` returns a value with the magnitude of the first operand and + the sign of the second operand. It takes two operands and returns one + result of the same type. This type may be a float scalar type, a vector + whose element type is float, or a tensor of floats. It has no standard + attributes. + }]; +} + +def CosOp : FloatUnaryOp<"cos"> { + let summary = "cosine of the specified value"; + let description = [{ + The `cos` operation computes the cosine of a given value. It takes one + operand and returns one result of the same type. This type may be a float + scalar type, a vector whose element type is float, or a tensor of floats. + It has no standard attributes. 
+ }]; +} + +def DeallocOp : Std_Op<"dealloc"> { + let summary = "memory deallocation operation"; + let description = [{ + The "dealloc" operation frees the region of memory referenced by a memref + which was originally created by the "alloc" operation. + The "dealloc" operation should not be called on memrefs which alias an + alloc'd memref (i.e. memrefs returned by the "view" and "reshape" + operations). + + %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + }]; + + let arguments = (ins AnyMemRef:$memref); + + let hasCanonicalizer = 1; + let hasFolder = 1; +} + +def DimOp : Std_Op<"dim", [NoSideEffect]> { + let summary = "dimension index operation"; + let description = [{ + The "dim" operation takes a memref or tensor operand and returns an "index". + It requires a single integer attribute named "index". It returns the size + of the specified dimension. For example: + + %1 = dim %0, 2 : tensor + }]; + + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor], + "any tensor or memref type">:$memrefOrTensor, + APIntAttr:$index); + let results = (outs Index); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value memrefOrTensor," + "unsigned index", [{ + auto indexType = builder->getIndexType(); + auto indexAttr = builder->getIntegerAttr(indexType, index); + build(builder, result, indexType, memrefOrTensor, indexAttr); + }]>]; + + let extraClassDeclaration = [{ + unsigned getIndex() { + return getAttrOfType("index").getValue().getZExtValue(); + } + }]; + + let hasFolder = 1; +} + +def DivFOp : FloatArithmeticOp<"divf"> { + let summary = "floating point division operation"; +} + +def SignedDivIOp : IntArithmeticOp<"divi_signed"> { + let summary = "signed integer division operation"; + let hasFolder = 1; +} + +def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> { + let summary = "unsigned integer division operation"; + let hasFolder = 1; +} + +def ExpOp : FloatUnaryOp<"exp"> { + 
let summary = "base-e exponential of the specified value"; +} + +def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { + let summary = "element extract operation"; + let description = [{ + The "extract_element" op reads a tensor or vector and returns one element + from it specified by an index list. The output of extract is a new value + with the same type as the elements of the tensor or vector. The arity of + indices matches the rank of the accessed value (i.e., if a tensor is of rank + 3, then 3 indices are required for the extract). The indices should all be + of index type. For example: + + %3 = extract_element %0[%1, %2] : vector<4x4xi32> + }]; + + let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate, + Variadic:$indices); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value aggregate," + "ValueRange indices = {}", [{ + auto resType = aggregate->getType().cast() + .getElementType(); + build(builder, result, resType, aggregate, indices); + }]>]; + + let extraClassDeclaration = [{ + Value getAggregate() { return getOperand(0); } + + operand_range getIndices() { + return {operand_begin() + 1, operand_end()}; + } + }]; + + let hasFolder = 1; +} + +def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { + let summary = "cast between index and integer types"; + let description = [{ + Casts between integer scalars and 'index' scalars. Index is an integer of + platform-specific bit width. If casting to a wider integer, the value is + sign-extended. If casting to a narrower integer, the value is truncated. + }]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. 
+ static bool areCastCompatible(Type a, Type b); + }]; + + let hasFolder = 0; +} + +def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> { + let summary = "cast from floating-point to wider floating-point"; + let description = [{ + Cast a floating-point value to a larger floating-point-typed value. + The destination type must be strictly wider than the source type. + Only scalars are currently supported. + }]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + }]; + + let hasFolder = 0; +} + +def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> { + let summary = "cast from floating-point to narrower floating-point"; + let description = [{ + Truncate a floating-point value to a smaller floating-point-typed value. + The destination type must be strictly narrower than the source type. + If the value cannot be exactly represented, it is rounded using the default + rounding mode. Only scalars are currently supported. + }]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + }]; + + let hasFolder = 0; +} + +def LoadOp : Std_Op<"load"> { + let summary = "load operation"; + let description = [{ + The "load" op reads an element from a memref specified by an index list. The + output of load is a new value with the same type as the elements of the + memref. The arity of indices is the rank of the memref (i.e., if the memref + loaded from is of rank 3, then 3 indices are required for the load following + the memref identifier). 
For example: + + %3 = load %0[%1, %1] : memref<4x4xi32> + }]; + + let arguments = (ins AnyMemRef:$memref, Variadic:$indices); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value memref," + "ValueRange indices = {}", [{ + auto memrefType = memref->getType().cast(); + result.addOperands(memref); + result.addOperands(indices); + result.types.push_back(memrefType.getElementType()); + }]>]; + + let extraClassDeclaration = [{ + Value getMemRef() { return getOperand(0); } + void setMemRef(Value value) { setOperand(0, value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + operand_range getIndices() { return {operand_begin() + 1, operand_end()}; } + }]; + + let hasFolder = 1; +} + +def LogOp : FloatUnaryOp<"log"> { + let summary = "base-e logarithm of the specified value"; +} + +def Log10Op : FloatUnaryOp<"log10"> { + let summary = "base-10 logarithm of the specified value"; +} + +def Log2Op : FloatUnaryOp<"log2"> { + let summary = "base-2 logarithm of the specified value"; +} + +def MemRefCastOp : CastOp<"memref_cast"> { + let summary = "memref cast operation"; + let description = [{ + The "memref_cast" operation converts a memref from one type to an equivalent + type with a compatible shape. The source and destination types are + compatible if: + a. both are ranked memref types with the same element type, affine mappings, + address space, and rank but where the individual dimensions may add or + remove constant dimensions from the memref type. + + If the cast converts any dimensions from an unknown to a known size, then it + acts as an assertion that fails at runtime if the dynamic dimensions + disagree with resultant destination size. + + Example: + Assert that the input dynamic shape matches the destination static shape. + %2 = memref_cast %1 : memref to memref<4x4xf32> + Erase static shape information, replacing it with dynamic information. 
+ %3 = memref_cast %1 : memref<4xf32> to memref + + The same holds true for offsets and strides. + + Assert that the input dynamic shape matches the destination static stride. + %4 = memref_cast %1 : memref<12x4xf32, offset:?, strides: [?, ?]> to + memref<12x4xf32, offset:5, strides: [4, 1]> + Erase static offset and stride information, replacing it with + dynamic information. + %5 = memref_cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to + memref<12x4xf32, offset:?, strides: [?, ?]> + + b. either or both memref types are unranked with the same element type, and + address space. + + Example: + Cast to concrete shape. + %4 = memref_cast %1 : memref<*xf32> to memref<4x?xf32> + + Erase rank information. + %5 = memref_cast %1 : memref<4x?xf32> to memref<*xf32> + }]; + + let arguments = (ins AnyRankedOrUnrankedMemRef:$source); + let results = (outs AnyRankedOrUnrankedMemRef); + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a memref_cast is always a memref. + Type getType() { return getResult()->getType(); } + }]; +} + +def MulFOp : FloatArithmeticOp<"mulf"> { + let summary = "floating point multiplication operation"; + let hasFolder = 1; +} + +def MulIOp : IntArithmeticOp<"muli", [Commutative]> { + let summary = "integer multiplication operation"; + let hasFolder = 1; +} + +def NegFOp : FloatUnaryOp<"negf"> { + let summary = "floating point negation"; + let description = [{ + The `negf` operation computes the negation of a given value. It takes one + operand and returns one result of the same type. This type may be a float + scalar type, a vector whose element type is float, or a tensor of floats. + It has no standard attributes. 
+ }]; +} + +def OrOp : IntArithmeticOp<"or", [Commutative]> { + let summary = "integer binary or"; + let hasFolder = 1; +} + +def PrefetchOp : Std_Op<"prefetch"> { + let summary = "prefetch operation"; + let description = [{ + The "prefetch" op prefetches data from a memref location described with + subscript indices similar to std.load, and with three attributes: a + read/write specifier, a locality hint, and a cache type specifier as shown + below: + + prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32> + + The read/write specifier is either 'read' or 'write', the locality hint + ranges from locality<0> (no locality) to locality<3> (extremely local keep + in cache). The cache type specifier is either 'data' or 'instr' + and specifies whether the prefetch is performed on data cache or on + instruction cache. + }]; + + let arguments = (ins AnyMemRef:$memref, Variadic:$indices, + BoolAttr:$isWrite, + Confined, + IntMaxValue<3>]>:$localityHint, + BoolAttr:$isDataCache); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value memref," + "ArrayRef indices, bool isWrite, unsigned hint, bool isData", + [{ + auto hintAttr = builder->getI32IntegerAttr(hint); + auto isWriteAttr = builder->getBoolAttr(isWrite); + auto isDataCacheAttr = builder->getBoolAttr(isData); + result.addOperands(memref); + result.addOperands(indices); + result.addAttribute("localityHint", hintAttr); + result.addAttribute("isWrite", isWriteAttr); + result.addAttribute("isDataCache", isDataCacheAttr); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return memref()->getType().cast(); + } + static StringRef getLocalityHintAttrName() { return "localityHint"; } + static StringRef getIsWriteAttrName() { return "isWrite"; } + static StringRef getIsDataCacheAttrName() { return "isDataCache"; } + }]; + + let hasFolder = 1; +} + +def RankOp : Std_Op<"rank", [NoSideEffect]> { + let summary = "rank operation"; + let description = [{ + The "rank" 
operation takes a tensor operand and returns its rank. + + %1 = rank %0 : index + }]; + + let arguments = (ins AnyTensor); + let results = (outs Index); + let verifier = ?; + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value tensor", [{ + auto indexType = builder->getIndexType(); + build(builder, result, indexType, tensor); + }]>]; + + let hasFolder = 1; +} + +def RemFOp : FloatArithmeticOp<"remf"> { + let summary = "floating point division remainder operation"; +} + +def SignedRemIOp : IntArithmeticOp<"remi_signed"> { + let summary = "signed integer division remainder operation"; + let hasFolder = 1; +} + +def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> { + let summary = "unsigned integer division remainder operation"; + let hasFolder = 1; +} + +def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. For example: + + func @foo() : (i32, f8) { + ... + return %0, %1 : i32, f8 + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; +} + +def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "select operation"; + let description = [{ + The "select" operation chooses one value based on a binary condition + supplied as its first operand. If the value of the first operand is 1, the + second operand is chosen, otherwise the third operand is chosen. The second + and the third operand must have the same type. The operation applies + elementwise to vectors and tensors. The shape of all arguments must be + identical. 
For example, the maximum operation is obtained by combining + "select" with "cmpi" as follows. + + %2 = cmpi "gt" %0, %1 : i32 // %2 is i1 + %3 = select %2, %0, %1 : i32 + }]; + + let arguments = (ins BoolLike:$condition, AnyType:$true_value, + AnyType:$false_value); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value condition," + "Value trueValue, Value falseValue", [{ + result.addOperands({condition, trueValue, falseValue}); + result.addTypes(trueValue->getType()); + }]>]; + + let extraClassDeclaration = [{ + Value getCondition() { return condition(); } + Value getTrueValue() { return true_value(); } + Value getFalseValue() { return false_value(); } + }]; + + let hasFolder = 1; +} + +def SignExtendIOp : Std_Op<"sexti", + [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "integer sign extension operation"; + let description = [{ + The integer sign extension operation takes an integer input of + width M and an integer destination type of width N. The destination + bit-width must be larger than the input bit-width (N > M). + The top-most (N - M) bits of the output are filled with copies + of the most-significant bit of the input. 
+ + %1 = constant 5 : i3 // %1 is 0b101 + %2 = sexti %1 : i3 to i6 // %2 is 0b111101 + %3 = constant 2 : i3 // %3 is 0b010 + %4 = sexti %3 : i3 to i6 // %4 is 0b000010 + + %5 = sexti %0 : vector<2 x i32> to vector<2 x i64> + }]; + + let arguments = (ins IntegerLike:$value); + let results = (outs IntegerLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value value, Type destType", [{ + result.addOperands(value); + result.addTypes(destType); + }]>]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; +} + +def ShiftLeftOp : IntArithmeticOp<"shift_left"> { + let summary = "integer left-shift"; + let description = [{ + The shift_left operation shifts an integer value to the left by a variable + amount. The low order bits are filled with zeros. + + %1 = constant 5 : i8 // %1 is 0b00000101 + %2 = constant 3 : i8 + %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000 + }]; +} + +def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> { + let summary = "signed integer right-shift"; + let description = [{ + The shift_right_signed operation shifts an integer value to the right by + a variable amount. The integer is interpreted as signed. The high order + bits in the output are filled with copies of the most-significant bit + of the shifted value (which means that the sign of the value is preserved). + + %1 = constant 160 : i8 // %1 is 0b10100000 + %2 = constant 3 : i8 + %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 + %4 = constant 96 : i8 // %4 is 0b01100000 + %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 + }]; +} + +def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> { + let summary = "unsigned integer right-shift"; + let description = [{ + The shift_right_unsigned operation shifts an integer value to the right by + a variable amount. The integer is interpreted as unsigned. 
The high order + bits are always filled with zeros. + + %1 = constant 160 : i8 // %1 is 0b10100000 + %2 = constant 3 : i8 + %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 + }]; +} + +def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> { + let summary = "cast from integer type to floating-point"; + let description = [{ + Cast from a value interpreted as signed integer to the corresponding + floating-point value. If the value cannot be exactly represented, it is + rounded using the default rounding mode. Only scalars are currently + supported. + }]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + }]; + + let hasFolder = 0; +} + +def SplatOp : Std_Op<"splat", [NoSideEffect]> { + let summary = "splat or broadcast operation"; + let description = [{ + The "splat" op reads a value of integer or float type and broadcasts it into + a vector or a tensor. The output of splat is thus a new value of either + vector or tensor type with elemental type being its operand's type. + When the result is a tensor, it has to be statically shaped. + + %1 = splat %0 : vector<8xi32> + %2 = splat %0 : tensor<4x8xi32> + + TODO: Extend this operation to broadcast to dynamically shaped tensors in + the same way dynamically shaped memrefs are handled. + + // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding + // to the sizes of the two dynamic dimensions. 
+ + %m = "foo"() : () -> (index) + %n = "bar"() : () -> (index) + %t = splat %s [%m, %n] : tensor + + }]; + + let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat], + "integer or float type">:$input); + let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate); + + let builders = + [OpBuilder<"Builder *builder, OperationState &result, Value element, " + "Type aggregateType", + [{ build(builder, result, aggregateType, element); }]>]; + + let hasFolder = 1; +} + +def StoreOp : Std_Op<"store"> { + let summary = "store operation"; + let description = [{ + The "store" op writes an element to a memref specified by an index list. + The arity of indices is the rank of the memref (i.e. if the memref being + stored to is of rank 3, then 3 indices are required for the store following + the memref identifier). The store operation does not produce a result. + + In the following example, the ssa value '%v' is stored in memref '%A' at + indices [%i, %j]: + store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0> + }]; + + let arguments = (ins AnyType:$value, AnyMemRef:$memref, + Variadic:$indices); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value valueToStore, Value memref", [{ + result.addOperands(valueToStore); + result.addOperands(memref); + }]>]; + + let extraClassDeclaration = [{ + Value getValueToStore() { return getOperand(0); } + + Value getMemRef() { return getOperand(1); } + void setMemRef(Value value) { setOperand(1, value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + operand_range getIndices() { + return {operand_begin() + 2, operand_end()}; + } + }]; + + let hasFolder = 1; +} + +def SubFOp : FloatArithmeticOp<"subf"> { + let summary = "floating point subtraction operation"; + let hasFolder = 1; +} + +def SubIOp : IntArithmeticOp<"subi"> { + let summary = "integer subtraction operation"; + let hasFolder = 1; +} + +def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, 
NoSideEffect]> { + let summary = "memref subview operation"; + let description = [{ + The "subview" operation converts a memref type to another memref type + which represents a reduced-size view of the original memref as specified by + the operation's offsets, sizes and strides arguments. + + The SubView operation supports the following arguments: + *) Memref: the "base" memref on which to create a "view" memref. + *) Offsets: zero or memref-rank number of dynamic offsets into the "base" + memref at which to create the "view" memref. + *) Sizes: zero or memref-rank dynamic size operands which specify the + dynamic sizes of the result "view" memref type. + *) Strides: zero or memref-rank number of dynamic strides which are applied + multiplicatively to the base memref strides in each dimension. + + Note on the number of operands for offsets, sizes and strides: For + each of these, the number of operands must either be same as the + memref-rank number or empty. For the latter, those values will be + treated as constants. + + Example 1: + + %0 = alloc() : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> + + // Create a sub-view of "base" memref '%0' with offset arguments '%c0', + // dynamic sizes for each dimension, and stride arguments '%c1'. + %1 = subview %0[%c0, %c0][%size0, %size1][%c1, %c1] + : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1) > to + memref (d0 * s1 + d1 + s0)> + + Example 2: + + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d1) -> (d0 * 64 + d1 * 4 + d2)> + + // Create a sub-view of "base" memref '%0' with dynamic offsets, sizes, + // and strides. + // Note that dynamic offsets are represented by the linearized dynamic + // offset symbol 's0' in the subview memref layout map, and that the + // dynamic strides operands, after being applied to the base memref + // strides in each dimension, are represented in the view memref layout + // map as symbols 's1', 's2' and 's3'. 
+ %1 = subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to + memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)> + + Example 3: + + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d1) -> (d0 * 64 + d1 * 4 + d2)> + + // Subview with constant offsets, sizes and strides. + %1 = subview %0[][][] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to + memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)> + + Example 4: + + %0 = alloc(%arg0, %arg1) : memref + + // Subview with constant size, but dynamic offsets and + // strides. The resulting memref has a static shape, but if the + // base memref has an affine map to describe the layout, the result + // memref also uses an affine map to describe the layout. The + // strides of the result memref is computed as follows: + // + // Let #map1 represents the layout of the base memref, and #map2 + // represents the layout of the result memref. A #mapsubview can be + // constructed to map an index from the result memref to the base + // memref (note that the description below uses more convenient + // naming for symbols, while in affine maps, symbols are + // represented as unsigned numbers that identify that symbol in the + // given affine map. + // + // #mapsubview = (d0, d1)[o0, o1, t0, t1] -> (d0 * t0 + o0, d1 * t1 + o1) + // + // where, o0, o1, ... are offsets, and t0, t1, ... are strides. Then, + // + // #map2 = #map1.compose(#mapsubview) + // + // If the layout map is represented as + // + // #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0) + // + // then, + // + // #map2 = (d0, d1)[s0, s1, s2, o0, o1, t0, t1] -> + // (d0 * s1 * t0 + d1 * s2 * t1 + o0 * s1 + o1 * s2 + s0) + // + // Representing this canonically + // + // #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0) + // + // where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1. 
+ %1 = subview %0[%i, %j][][%x, %y] : + : memref (d0 * s1 + d1 * s2 + s0)> to + memref<4x4xf32, (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)> + + // Note that the subview op does not guarantee that the result + // memref is "inbounds" w.r.t to base memref. It is upto the client + // to ensure that the subview is accessed in a manner that is + // in-bounds. + + } + }]; + + // TODO(b/144779634, ravishankarm) : Use different arguments for + // offsets, sizes and strides. + let arguments = (ins + AnyMemRef:$source, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + I32ElementsAttr:$operand_segment_sizes + ); + let results = (outs AnyMemRef); + + let builders = [ + OpBuilder< + "Builder *b, OperationState &result, Value source, " + "ValueRange offsets, ValueRange sizes, " + "ValueRange strides, Type resultType = Type(), " + "ArrayRef attrs = {}">, + OpBuilder< + "Builder *builder, OperationState &result, " + "Type resultType, Value source"> + ]; + + let extraClassDeclaration = [{ + /// Returns the type of the base memref operand. + MemRefType getBaseMemRefType() { + return source()->getType().cast(); + } + + /// The result of a subview is always a memref. + MemRefType getType() { return getResult()->getType().cast(); } + + /// Returns as integer value the number of offset operands. + int64_t getNumOffsets() { return llvm::size(offsets()); } + + /// Returns as integer value the number of size operands. + int64_t getNumSizes() { return llvm::size(sizes()); } + + /// Returns as integer value the number of stride operands. + int64_t getNumStrides() { return llvm::size(strides()); } + + /// Returns the dynamic sizes for this subview operation if specified. + operand_range getDynamicSizes() { return sizes(); } + + /// Returns in `staticStrides` the static value of the stride + /// operands. Returns failure() if the static value of the stride + /// operands could not be retrieved. 
+ LogicalResult getStaticStrides(SmallVectorImpl &staticStrides); + + // Auxiliary range data structure and helper function that unpacks the + // offset, size and stride operands of the SubViewOp into a list of triples. + // Such a list of triple is sometimes more convenient to manipulate. + struct Range { + Value offset, size, stride; + }; + SmallVector getRanges(); + }]; + + let hasCanonicalizer = 1; +} + +def TanhOp : FloatUnaryOp<"tanh"> { + let summary = "hyperbolic tangent of the specified value"; + let description = [{ + The `tanh` operation computes the hyperbolic tangent. It takes one operand + and returns one result of the same type. This type may be a float scalar + type, a vector whose element type is float, or a tensor of floats. It has + no standard attributes. + }]; +} + +def TensorCastOp : CastOp<"tensor_cast"> { + let summary = "tensor cast operation"; + let description = [{ + The "tensor_cast" operation converts a tensor from one type to an equivalent + type without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + + Convert from unknown rank to rank 2 with unknown dimension sizes. + %2 = tensor_cast %1 : tensor<*xf32> to tensor + }]; + + let arguments = (ins AnyTensor); + let results = (outs AnyTensor); + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a tensor_cast is always a tensor. 
+ TensorType getType() { return getResult()->getType().cast(); } + }]; +} + +def TensorLoadOp : Std_Op<"tensor_load", + [SameOperandsAndResultShape, SameOperandsAndResultElementType]> { + let summary = "tensor load operation"; + let description = [{ + The "tensor_load" operation creates a tensor from a memref, making an + independent copy of the element data. The result value is a tensor whose + shape and element type match the memref operand. + + Produce a value of tensor<4x?xf32> type. + %12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0> + }]; + + let arguments = (ins AnyMemRef); + let results = (outs AnyTensor); + // TensorLoadOp is fully verified by traits. + let verifier = ?; + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value memref", [{ + auto memrefType = memref->getType().cast(); + auto resultType = RankedTensorType::get(memrefType.getShape(), + memrefType.getElementType()); + result.addOperands(memref); + result.addTypes(resultType); + }]>]; + + + let extraClassDeclaration = [{ + /// The result of a tensor_load is always a tensor. + TensorType getType() { return getResult()->getType().cast(); } + }]; +} + +def TensorStoreOp : Std_Op<"tensor_store", + [SameOperandsShape, SameOperandsElementType]> { + let summary = "tensor store operation"; + let description = [{ + The "tensor_store" operation stores the contents of a tensor into a memref. + The first operand is a value of tensor type, the second operand is a value + of memref type. The shapes and element types of these must match, and are + specified by the memref type. + + Example: + %9 = dim %8, 1 : tensor<4x?xf32> + %10 = alloc(%9) : memref<4x?xf32, #layout, memspace0> + tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0> + }]; + + let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref); + // TensorStoreOp is fully verified by traits. 
+ let verifier = ?; +} + +def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "integer truncation operation"; + let description = [{ + The integer truncation operation takes an integer input of + width M and an integer destination type of width N. The destination + bit-width must be smaller than the input bit-width (N < M). + The top-most (N - M) bits of the input are discarded. + + %1 = constant 21 : i5 // %1 is 0b10101 + %2 = trunci %1 : i5 to i4 // %2 is 0b0101 + %3 = trunci %1 : i5 to i3 // %3 is 0b101 + + %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> + }]; + + let arguments = (ins IntegerLike:$value); + let results = (outs IntegerLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value value, Type destType", [{ + result.addOperands(value); + result.addTypes(destType); + }]>]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; +} + +def ViewOp : Std_Op<"view", [NoSideEffect]> { + let summary = "memref view operation"; + let description = [{ + The "view" operation converts a 1-D memref with i8 element type, + to an N-D memref with arbitrary element type. In addition, the ViewOp + supports the following arguments: + *) A single dynamic offset operand can be specified which represents a + a dynamic offset within the base 1-D memref at which to create the + resulting memref view. + *) A dynamic size operand must be specified for each dynamic dimension + in the resulting view memref type. + + // Allocate a flat 1D/i8 memref. + %0 = alloc() : memref<2048xi8> + + // ViewOp with static offset and sizes. + %1 = view %0[][] : memref<2048xi8> to memref<64x4xf32> + + // ViewOp with dynamic offset and one dynamic size. + %2 = view %0[%offset_1024][%size0] + : memref<2048xi8> to memref (d0 * 4 + d1 + s0)> + + // ViewOp creating 3D shape where two of the dim sizes are dynamic. 
+ // *) The dynamic offset specified in the ViewOp is applied to the + // base 1-D memref, and is represented by the symbol 's0' in the + // layout map of the ViewOp result memref type. + // *) The dynamic size for the second dimension induces a dynamic + // stride for the first dimension, which is represented by the + // symbol 's1' in the layout map of the ViewOp result memref type. + // Note that this dynamic stride will be computed from the view + // shape and dynamic sizes. + %3 = view %0[%offset_1024][%size0, %size1] + : memref<2048xi8> to memref (d0 * s1 + d1 * 4 + d2 + s0)> + }]; + + let arguments = (ins MemRefRankOf<[I8], [1]>:$source, + Variadic:$operands); + let results = (outs AnyMemRef); + + let extraClassDeclaration = [{ + /// The result of a view is always a memref. + MemRefType getType() { return getResult()->getType().cast(); } + + /// Returns the dynamic offset for this view operation if specified. + /// Returns nullptr if no dynamic offset was specified. + Value getDynamicOffset(); + + /// Returns the starting operand list position of the dynamic size operands. + unsigned getDynamicSizesOperandStart() { + return getDynamicOffset() == nullptr ? 1 : 2; + } + + /// Returns the dynamic sizes for this view operation. + operand_range getDynamicSizes() { + return {operand_begin() + getDynamicSizesOperandStart(), operand_end()}; + } + }]; + + let hasCanonicalizer = 1; +} + +def XOrOp : IntArithmeticOp<"xor", [Commutative]> { + let summary = "integer binary xor"; + let hasFolder = 1; +} + +def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "integer zero extension operation"; + let description = [{ + The integer zero extension operation takes an integer input of + width M and an integer destination type of width N. The destination + bit-width must be larger than the input bit-width (N > M). + The top-most (N - M) bits of the output are filled with zeros. 
+ + %1 = constant 5 : i3 // %1 is 0b101 + %2 = zexti %1 : i3 to i6 // %2 is 0b000101 + %3 = constant 2 : i3 // %3 is 0b010 + %4 = zexti %3 : i3 to i6 // %4 is 0b000010 + + %5 = zexti %0 : vector<2 x i32> to vector<2 x i64> + }]; + + let arguments = (ins IntegerLike:$value); + let results = (outs IntegerLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value value, Type destType", [{ + result.addOperands(value); + result.addTypes(destType); + }]>]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; +} + +#endif // STANDARD_OPS diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h new file mode 100644 index 0000000000000000000000000000000000000000..87c8e662a65521eeb187d11d0f8df54016866114 --- /dev/null +++ b/mlir/include/mlir/Dialect/Traits.h @@ -0,0 +1,80 @@ +//===- Traits.h - Common op traits shared by dialects -----------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares common op traits that are not core to MLIR but can be +// shared by multiple dialects. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRAITS +#define MLIR_DIALECT_TRAITS + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. 
+namespace impl { +LogicalResult verifyCompatibleOperandBroadcast(Operation *op); +} // namespace impl + +namespace util { +/// Returns true and sets `resultShape` to the broadcasted shape from the two +/// given shapes if they are broadcast compatible. Returns false and clears +/// `resultShape` otherwise. +/// +/// The rules for determining the result shape are: +/// +/// Zip together the dimensions in the two given shapes by prepending the shape +/// with less dimensions with 1s. For each dimension pair, deduces the result +/// dimension according to the following order: +/// - If there are unknown dimensions, follows the TensorFlow behavior: +/// - If either dimension is greater than 1, we assume that the program is +/// correct, and the other dimension will be broadcast to match it. +/// - If either dimension is 1, the other dimension is the result. +/// - Otherwise, the result dimension is unknown dimension. +/// - If one of the dimension is 1, the other dimension is the result. +/// - If two dimensions are the same, that's the result. +/// - Otherwise, incompatible shape. +bool getBroadcastedShape(ArrayRef shape1, ArrayRef shape2, + SmallVectorImpl &resultShape); + +/// Returns the result broadcast composition type from the two given types by +/// following NumPy broadcast semantics. Returned type may have dynamic shape if +/// either of the input types has dynamic shape. Returns null type if the two +/// given types are not broadcast-compatible. +Type getBroadcastedType(Type type1, Type type2); +} // namespace util + +/// This class provides the API for ops that are known to have broadcast- +/// compatible operand and result types. Specifically, starting from the +/// most varying dimension, each dimension pair of the two operands' types +/// should either be the same or one of them is one. Also, the result type +/// should have the corresponding dimension equal to the larger one, if known. 
+/// Shapes are checked partially if ranks or dimensions are not known. For +/// example, an op with tensor and tensor <2 x f32> as operand +/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible. +/// +/// Ths trait assumes the op has two operands and one result, and it asserts +/// if the pre-condition is not satisfied. +template +class BroadcastableTwoOperandsOneResult + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyCompatibleOperandBroadcast(op); + } +}; + +} // end namespace OpTrait +} // end namespace mlir + +#endif // MLIR_DIALECT_TRAITS diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..9e7cbba0f433996c35504c00148d524645c025eb --- /dev/null +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -0,0 +1,105 @@ +//===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file define utilities that operate on standard types and are +// useful across multiple dialects that use structured ops abstractions. These +// abstractions consist of define custom operations that encode and transport +// information about their semantics (e.g. type of iterators like parallel, +// reduction, etc..) as attributes. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H +#define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +/// Attribute name for the AffineArrayAttr which encodes the relationship +/// between a structured op iterators' and its operands. +static constexpr StringLiteral getIndexingMapsAttrName() { + return StringLiteral("indexing_maps"); +} + +/// Attribute name for the StrArrayAttr which encodes the type of a structured +/// op's iterators. +static constexpr StringLiteral getIteratorTypesAttrName() { + return StringLiteral("iterator_types"); +} + +/// Attribute name for the IntegerAttr which encodes the number of input buffer +/// arguments. +static constexpr StringLiteral getArgsInAttrName() { + return StringLiteral("args_in"); +} + +/// Attribute name for the IntegerAttr which encodes the number of input buffer +/// arguments. +static constexpr StringLiteral getArgsOutAttrName() { + return StringLiteral("args_out"); +} + +/// Attribute name for the StringAttr which encodes an optional documentation +/// string of the structured op. +static constexpr StringLiteral getDocAttrName() { return StringLiteral("doc"); } + +/// Attribute name for the StrArrayAttr which encodes the SymbolAttr for the +/// MLIR function that implements the body of the structured op. +static constexpr StringLiteral getFunAttrName() { return StringLiteral("fun"); } + +/// Attribute name for the StrArrayAttr which encodes the external library +/// function that implements the structured op. +static constexpr StringLiteral getLibraryCallAttrName() { + return StringLiteral("library_call"); +} + +/// Use to encode that a particular iterator type has parallel semantics. 
+inline static constexpr StringLiteral getParallelIteratorTypeName() { + return StringLiteral("parallel"); +} + +/// Use to encode that a particular iterator type has reduction semantics. +inline static constexpr StringLiteral getReductionIteratorTypeName() { + return StringLiteral("reduction"); +} + +/// Use to encode that a particular iterator type has window semantics. +inline static constexpr StringLiteral getWindowIteratorTypeName() { + return StringLiteral("window"); +} + +/// Use to encode that a particular iterator type has window semantics. +inline static ArrayRef getAllIteratorTypeNames() { + static StringRef names[3] = {getParallelIteratorTypeName(), + getReductionIteratorTypeName(), + getWindowIteratorTypeName()}; + return llvm::makeArrayRef(names); +} + +/// Returns the iterator of a certain type. +inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) { + auto names = getAllIteratorTypeNames(); + (void)names; + assert(llvm::is_contained(names, name)); + return llvm::count_if(iteratorTypes, [name](Attribute a) { + return a.cast().getValue() == name; + }); +} + +inline unsigned getNumIterators(ArrayAttr iteratorTypes) { + unsigned res = 0; + for (auto n : getAllIteratorTypeNames()) + res += getNumIterators(n, iteratorTypes); + return res; +} + +} // end namespace mlir + +#endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H diff --git a/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt b/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ce3168c55800dc026456021b4ae6770e7a23493 --- /dev/null +++ b/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_dialect(VectorOps VectorOps) + +set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td) +mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRVectorTransformPatternsIncGen) diff --git a/mlir/include/mlir/Dialect/VectorOps/Utils.h 
b/mlir/include/mlir/Dialect/VectorOps/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..5f19f849e3fe55224b8761a04b81e65636d4ae4c --- /dev/null +++ b/mlir/include/mlir/Dialect/VectorOps/Utils.h @@ -0,0 +1,134 @@ +//===- Utils.h - VectorOps Utils ----------------------------*- C++ -*-=======// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOROPS_UTILS_H_ +#define MLIR_DIALECT_VECTOROPS_UTILS_H_ + +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/DenseMap.h" + +namespace mlir { + +class AffineApplyOp; +class AffineForOp; +class AffineMap; +class Location; +class MemRefType; +class OpBuilder; +class Operation; +class Value; +class VectorType; + +/// Computes and returns the multi-dimensional ratio of `superShape` to +/// `subShape`. This is calculated by performing a traversal from minor to major +/// dimensions (i.e. in reverse shape order). If integral division is not +/// possible, returns None. +/// The ArrayRefs are assumed (and enforced) to only contain > 1 values. +/// This constraint comes from the fact that they are meant to be used with +/// VectorTypes, for which the property holds by construction. +/// +/// Examples: +/// - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4} +/// - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None +/// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16} +Optional> shapeRatio(ArrayRef superShape, + ArrayRef subShape); + +/// Computes and returns the multi-dimensional ratio of the shapes of +/// `superVector` to `subVector`. If integral division is not possible, returns +/// None. +/// Assumes and enforces that the VectorTypes have the same elemental type. 
+Optional> shapeRatio(VectorType superVectorType, + VectorType subVectorType); + +/// Constructs a permutation map of invariant memref indices to vector +/// dimension. +/// +/// If no index is found to be invariant, 0 is added to the permutation_map and +/// corresponds to a vector broadcast along that dimension. +/// +/// The implementation uses the knowledge of the mapping of loops to +/// vector dimension. `loopToVectorDim` carries this information as a map with: +/// - keys representing "vectorized enclosing loops"; +/// - values representing the corresponding vector dimension. +/// Note that loopToVectorDim is a whole function map from which only enclosing +/// loop information is extracted. +/// +/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at +/// most one invariant index along each AffineForOp of `loopToVectorDim`). +/// +/// Example 1: +/// The following MLIR snippet: +/// +/// ```mlir +/// affine.for %i3 = 0 to %0 { +/// affine.for %i4 = 0 to %1 { +/// affine.for %i5 = 0 to %2 { +/// %a5 = load %arg0[%i4, %i5, %i3] : memref +/// }}} +/// ``` +/// +/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into: +/// +/// ```mlir +/// affine.for %i3 = 0 to %0 step 32 { +/// affine.for %i4 = 0 to %1 { +/// affine.for %i5 = 0 to %2 step 256 { +/// %4 = vector.transfer_read %arg0, %i4, %i5, %i3 +/// {permutation_map: (d0, d1, d2) -> (d2, d1)} : +/// (memref, index, index) -> vector<32x256xf32> +/// }}} +/// ``` +/// +/// Meaning that vector.transfer_read will be responsible for reading the slice: +/// `%arg0[%i4, %i5:%15+256, %i3:%i3+32]` into vector<32x256xf32>. 
+/// +/// Example 2: +/// The following MLIR snippet: +/// +/// ```mlir +/// %cst0 = constant 0 : index +/// affine.for %i0 = 0 to %0 { +/// %a0 = load %arg0[%cst0, %cst0] : memref +/// } +/// ``` +/// +/// may vectorize with {permutation_map: (d0) -> (0)} into: +/// +/// ```mlir +/// affine.for %i0 = 0 to %0 step 128 { +/// %3 = vector.transfer_read %arg0, %c0_0, %c0_0 +/// {permutation_map: (d0, d1) -> (0)} : +/// (memref, index, index) -> vector<128xf32> +/// } +/// ```` +/// +/// Meaning that vector.transfer_read will be responsible of reading the slice +/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. +/// +AffineMap +makePermutationMap(Operation *op, ArrayRef indices, + const DenseMap &loopToVectorDim); + +namespace matcher { + +/// Matches vector.transfer_read, vector.transfer_write and ops that return a +/// vector type that is a multiple of the sub-vector type. This allows passing +/// over other smaller vector types in the function and avoids interfering with +/// operations on those. +/// This is a first approximation, it can easily be extended in the future. +/// TODO(ntv): this could all be much simpler if we added a bit that a vector +/// type to mark that a vector is a strict super-vector but it still does not +/// warrant adding even 1 extra bit in the IR for now. +bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType); + +} // end namespace matcher +} // end namespace mlir + +#endif // MLIR_DIALECT_VECTOROPS_UTILS_H_ diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h new file mode 100644 index 0000000000000000000000000000000000000000..7234d46b765669f744bfa2ac0caece671a7ed018 --- /dev/null +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -0,0 +1,59 @@ +//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Vector dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOROPS_VECTOROPS_H +#define MLIR_DIALECT_VECTOROPS_VECTOROPS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +class MLIRContext; +class OwningRewritePatternList; +namespace vector { + +/// Dialect for Ops on higher-dimensional vector types. +class VectorOpsDialect : public Dialect { +public: + VectorOpsDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "vector"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; +}; + +/// Collect a set of vector-to-vector canonicalization patterns. +void populateVectorToVectorCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context); + +/// Collect a set of vector-to-vector transformation patterns. +void populateVectorToVectorTransformationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context); + +/// Returns the integer type required for subscripts in the vector dialect. +IntegerType getVectorSubscriptType(Builder &builder); + +/// Returns an integer array attribute containing the given values using +/// the integer type required for subscripts in the vector dialect. 
+ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef values); + +#define GET_OP_CLASSES +#include "mlir/Dialect/VectorOps/VectorOps.h.inc" + +} // end namespace vector +} // end namespace mlir + +#endif // MLIR_DIALECT_VECTOROPS_VECTOROPS_H diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td new file mode 100644 index 0000000000000000000000000000000000000000..8726b162fd6169fda6f1469781c0050e8545da5a --- /dev/null +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -0,0 +1,1152 @@ +//===- VectorOps.td - Vector op definitions ---------------*- tablegen -*-====// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines MLIR vector operations. +// +//===----------------------------------------------------------------------===// + +#ifndef VECTOR_OPS +#define VECTOR_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/AffineOps/AffineOpsBase.td" + +def Vector_Dialect : Dialect { + let name = "vector"; + let cppNamespace = "vector"; +} + +// Base class for Vector dialect ops. +class Vector_Op traits = []> : + Op { + // For every vector op, there needs to be a: + // * void print(OpAsmPrinter &p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, + // OperationState &result) + // functions. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +// TODO(andydavis, ntv) Add an attribute to specify a different algebra +// with operators other than the current set: {*, +}. 
+def Vector_ContractionOp : + Vector_Op<"contract", [NoSideEffect]>, + Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc, + Variadic>:$masks, + AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, + Results<(outs AnyVector)> { + let summary = "vector contraction operation"; + let description = [{ + Computes the sum of products of vector elements along contracting + dimension pairs from 2 vectors of rank M and N respectively, adds this + intermediate result to the accumulator argument of rank K, and returns a + vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims + + num_batch_dims (see dimension type descriptions below)). + + Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp) + specify the dynamic dimension sizes of valid data within the lhs/rhs vector + arguments. + + An iterator type attribute list must be specified, where each element of + the list represents an iterator with one of the following types: + + *) "reduction": reduction dimensions are present in the lhs and rhs + arguments but not in the output (or optional accumulator + argument). These are the dimensions along which the vector + contraction op computes the sum of products, and + contracting dimension pair dimension sizes must match + between lhs/rhs. + *) "parallel": Batch dimensions are iterator type "parallel", and + are non-contracting dimensions present in the lhs, rhs and + output. The lhs/rhs co-iterate along the batch dimensions, + which should be expressed in their indexing maps. + + Free dimensions are iterator type "parallel", and are + non-contraction, non-batch dimensions accessed by either the + lhs or rhs (but not both). The lhs and rhs free dimensions + are unrelated to each other and do not co-iterate, which + should be expressed in their indexing maps. + + An indexing map attribute list must be specified with an entry for lhs, rhs + and acc arguments. 
An indexing map attribute specifies a mapping from each + iterator in the iterator type list, to each dimension of an N-D vector. + + Examples: + + // 2D vector contraction with one contracting dimension (matmul). + #contraction_accesses = [ + (i, j, k) -> (i, k), + (i, j, k) -> (k, j), + (i, j, k) -> (i, j) + ] + #contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = [parallel, parallel, reduction] + } + + %3 = vector.contract #contraction_trait %0, %1, %2 + : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32> + + // 4D to 3D vector contraction with two contracting dimensions and + // one batch dimension. + #contraction_accesses = [ + (b0, f0, f1, c0, c1) -> (c0, b0, c1, f0), + (b0, f0, f1, c0, c1) -> (b0, c1, c0, f1), + (b0, f0, f1, c0, c1) -> (b0, f0, f1) + ] + #contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = [parallel, parallel, parallel, reduction, reduction] + } + + %4 = vector.contract #contraction_trait %0, %1, %2 + : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> + + // 4D vector contraction with two contracting dimensions and optional + // vector mask arguments.
+ %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> + %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> + + %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask + : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value lhs, Value rhs, " + "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">]; + let extraClassDeclaration = [{ + VectorType getLhsType() { + return lhs()->getType().cast(); + } + VectorType getRhsType() { + return rhs()->getType().cast(); + } + VectorType getAccType() { + return acc()->getType().cast(); + } + VectorType getLHSVectorMaskType() { + if (llvm::size(masks()) != 2) return VectorType(); + return getOperand(3)->getType().cast(); + } + VectorType getRHSVectorMaskType() { + if (llvm::size(masks()) != 2) return VectorType(); + return getOperand(4)->getType().cast(); + } + VectorType getResultType() { + return getResult()->getType().cast(); + } + ArrayRef getTraitAttrNames(); + SmallVector getIndexingMaps(); + static unsigned getAccOperandIndex() { return 2; } + + // Returns the bounds of each dimension in the iteration space spanned + // by the iterator types of this operation. + void getIterationBounds(SmallVectorImpl &iterationBounds); + + // Returns a list of index maps, where there is a list entry for each + // op indexing map attribute (i.e. one for each input and output, with + // the output listed last). Each index map, maps from this operations + // iteration space, to vector dimensions of the maps input/output. 
+ void getIterationIndexMap( + std::vector> &iterationIndexMap); + + std::vector> getContractingDimMap(); + std::vector> getBatchDimMap(); + }]; +} + +def Vector_BroadcastOp : + Vector_Op<"broadcast", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins AnyType:$source)>, + Results<(outs AnyVector:$vector)> { + let summary = "broadcast operation"; + let description = [{ + Broadcasts the scalar or k-D vector value in the source operand + to a n-D result vector such that the broadcast makes sense, i.e., + the source operand is duplicated to match the given rank and sizes + in the result vector. The legality rules are: + * the source operand must have the same element type as the result type + * a k-D vector can be broadcast to + a n-D vector if + * k <= n, and + * the sizes in the trailing dimensions n-k < i <= n with j=i+k-n + match exactly as s_j = t_i or s_j = 1: + ``` + t_1 x .. t_n-k x t_n-k+1 x .. x t_i x .. x t_n + s_1 x .. x s_j x .. x s_k + + ``` + The source operand is duplicated over all the missing leading dimensions + and stretched over the trailing dimensions where the source has a non-equal + dimension of 1. These rules imply that any scalar broadcast (k=0) to any + shaped vector with the same element type is always legal. 
+ + Examples: + ``` + %0 = constant 0.0 : f32 + %1 = vector.broadcast %0 : f32 to vector<16xf32> + %2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32> + ``` + }]; + let extraClassDeclaration = [{ + Type getSourceType() { return source()->getType(); } + VectorType getVectorType() { + return vector()->getType().cast(); + } + }]; +} + +def Vector_ShuffleOp : + Vector_Op<"shuffle", [NoSideEffect, + PredOpTrait<"first operand v1 and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"second operand v2 and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>, + Results<(outs AnyVector:$vector)> { + let summary = "shuffle operation"; + let description = [{ + The shuffle operation constructs a permutation (or duplication) of elements + from two input vectors, returning a vector with the same element type as + the input and a length that is the same as the shuffle mask. The two input + vectors must have the same element type, rank, and trailing dimension sizes + and shuffles their values in the leading dimension (which may differ in size) + according to the given mask. The legality rules are: + * the two operands must have the same element type as the result + * the two operands and the result must have the same rank and trailing + dimension sizes, viz. given two k-D operands + v1 : and + v2 : + we have s_i = t_i for all 1 < i <= k + * the mask length equals the leading dimension size of the result + * numbering the input vector indices left to right accross the operands, all + mask values must be within range, viz. 
given two k-D operands v1 and v2 + above, all mask values are in the range [0,s_1+t_1) + + Examples: + ``` + %0 = vector.shuffle %a, %b[0, 3] + : vector<2xf32>, vector<2xf32> ; yields vector<2xf32> + %1 = vector.shuffle %c, %b[0, 1, 2] + : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32> + %2 = vector.shuffle %a, %b[3, 2, 1, 0] + : vector<2xf32>, vector<2xf32> ; yields vector<4xf32> + + ``` + }]; + let builders = [OpBuilder<"Builder *builder, OperationState &result," + "Value v1, Value v2, ArrayRef">]; + let extraClassDeclaration = [{ + static StringRef getMaskAttrName() { return "mask"; } + VectorType getV1VectorType() { + return v1()->getType().cast(); + } + VectorType getV2VectorType() { + return v2()->getType().cast(); + } + VectorType getVectorType() { + return vector()->getType().cast(); + } + }]; +} + +def Vector_ExtractElementOp : + Vector_Op<"extractelement", [NoSideEffect, + PredOpTrait<"operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins AnyVector:$vector, AnyInteger:$position)>, + Results<(outs AnyType)> { + let summary = "extractelement operation"; + let description = [{ + Takes an 1-D vector and a dynamic index position and extracts the + scalar at that position. Note that this instruction resembles + vector.extract, but is restricted to 1-D vectors and relaxed + to dynamic indices. 
It is meant to be closer to LLVM's version: + https://llvm.org/docs/LangRef.html#extractelement-instruction + + Example: + ``` + %c = constant 15 : i32 + %1 = vector.extractelement %0[%c : i32]: vector<16xf32> + ``` + }]; + let extraClassDeclaration = [{ + VectorType getVectorType() { + return vector()->getType().cast(); + } + }]; +} + +def Vector_ExtractOp : + Vector_Op<"extract", [NoSideEffect, + PredOpTrait<"operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>, + Results<(outs AnyType)> { + let summary = "extract operation"; + let description = [{ + Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at + the proper position. Degenerates to an element type in the 0-D case. + + Examples: + ``` + %1 = vector.extract %0[3]: vector<4x8x16xf32> + %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32> + ``` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value source," + "ArrayRef">]; + let extraClassDeclaration = [{ + static StringRef getPositionAttrName() { return "position"; } + VectorType getVectorType() { + return vector()->getType().cast(); + } + }]; +} + +def Vector_ExtractSlicesOp : + Vector_Op<"extract_slices", [NoSideEffect]>, + Arguments<(ins AnyVector:$vector, I64ArrayAttr:$sizes, + I64ArrayAttr:$strides)>, + Results<(outs TupleOf<[AnyVector]>)> { + let summary = "vector extract slices operation"; + let description = [{ + Takes an N-d vector and returns a tuple of vector slices of 'vector', + based on 'sizes' and 'strides' parameters. + + The arguments 'sizes' and 'strides' represent a specification for + generating the unrolling of 'vector' shape, which has all slices of shape + 'sizes' except for slices at dimension boundaries when 'vector' dimension + sizes are not a multiple of 'sizes'. 
+ + Each slice is returned at the tuple element index corresponding to the + linear index of the slice w.r.t the unrolling scheme represented by 'sizes'. + Currently, only unit strides are supported. + + Examples: + ``` + %0 = vector.transfer_read ...: vector<4x2xf32> + + %1 = vector.extract_slices %0, [2, 2], [1, 1] + : vector<4x2xf32> into tuple, vector<2x2xf32>> + + // Example with partial slices at dimension boundaries. + %2 = vector.transfer_read ...: vector<4x3xf32> + + %3 = vector.extract_slices %2, [2, 2], [1, 1] + : vector<4x3xf32> into tuple, vector<2x1xf32>, + vector<2x2xf32>, vector<2x1xf32>> + ``` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, TupleType tupleType, " # + "Value vector, ArrayRef sizes, " # + "ArrayRef strides">]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return vector()->getType().cast(); + } + TupleType getResultTupleType() { + return getResult()->getType().cast(); + } + void getSizes(SmallVectorImpl &results); + void getStrides(SmallVectorImpl &results); + static StringRef getSizesAttrName() { return "sizes"; } + static StringRef getStridesAttrName() { return "strides"; } + }]; +} + +def Vector_InsertElementOp : + Vector_Op<"insertelement", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"dest operand and result have same type", + TCresIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyType:$source, AnyVector:$dest, AnyInteger:$position)>, + Results<(outs AnyVector)> { + let summary = "insertelement operation"; + let description = [{ + Takes a scalar source, an 1-D destination vector and a dynamic index + position and inserts the source into the destination at the proper + position. Note that this instruction resembles vector.insert, but + is restricted to 1-D vectors and relaxed to dynamic indices. 
It is + meant to be closer to LLVM's version: + https://llvm.org/docs/LangRef.html#insertelement-instruction + + Example: + ``` + %c = constant 15 : i32 + %f = constant 0.0f : f32 + %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32> + ``` + }]; + let extraClassDeclaration = [{ + Type getSourceType() { return source()->getType(); } + VectorType getDestVectorType() { + return dest()->getType().cast(); + } + }]; +} + +def Vector_InsertOp : + Vector_Op<"insert", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"dest operand and result have same type", + TCresIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>, + Results<(outs AnyVector)> { + let summary = "insert operation"; + let description = [{ + Takes an n-D source vector, an (n+k)-D destination vector and a k-D position + and inserts the n-D source into the (n+k)-D destination at the proper + position. Degenerates to a scalar source type when n = 0. + + Examples: + ``` + %2 = vector.insert %0, %1[3]: + vector<8x16xf32> into vector<4x8x16xf32> + %5 = vector.insert %3, %4[3, 3, 3]: + f32 into vector<4x8x16xf32> + ``` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value source, " # + "Value dest, ArrayRef">]; + let extraClassDeclaration = [{ + static StringRef getPositionAttrName() { return "position"; } + Type getSourceType() { return source()->getType(); } + VectorType getDestVectorType() { + return dest()->getType().cast(); + } + }]; +} + +def Vector_InsertSlicesOp : + Vector_Op<"insert_slices", [NoSideEffect]>, + Arguments<(ins TupleOf<[AnyVector]>:$vectors, I64ArrayAttr:$sizes, + I64ArrayAttr:$strides)>, + Results<(outs AnyVector)> { + let summary = "vector insert slices operation"; + let description = [{ + Takes a tuple of vector slices and inserts them into the vector result + according to the 'sizes' and 'strides' parameters. 
+ + The arguments 'sizes' and 'strides' represent a specification for + generating the unrolling of 'vector' shape, which has all slices of shape + 'sizes' except for slices at dimension boundaries when 'vector' dimension + sizes are not a multiple of 'sizes'. + + Each slice in 'vectors' is at the tuple element index corresponding to the + linear index of the slice w.r.t the unrolling scheme represented by 'sizes'. + Currently, only unit strides are supported. + + Examples: + ``` + %0 = vector.extract_slices %0, [2, 2], [1, 1] + : vector<4x2xf32> into tuple, vector<2x2xf32>> + + %1 = vector.insert_slices %0, [2, 2], [1, 1] + : tuple, vector<2x2xf32>> into vector<4x2xf32> + + // Example with partial slices at dimension boundaries. + %3 = vector.extract_slices %2, [2, 2], [1, 1] + : vector<4x3xf32> into tuple, vector<2x1xf32>, + vector<2x2xf32>, vector<2x1xf32>> + + %4 = vector.insert_slices %3, [2, 2], [1, 1] + : tuple, vector<2x1xf32>, + vector<2x2xf32>, vector<2x1xf32>> into vector<4x3xf32> + ``` + }]; + + let extraClassDeclaration = [{ + TupleType getSourceTupleType() { + return vectors()->getType().cast(); + } + VectorType getResultVectorType() { + return getResult()->getType().cast(); + } + void getSizes(SmallVectorImpl &results); + void getStrides(SmallVectorImpl &results); + static StringRef getSizesAttrName() { return "sizes"; } + static StringRef getStridesAttrName() { return "strides"; } + }]; +} + +def Vector_InsertStridedSliceOp : + Vector_Op<"insert_strided_slice", [NoSideEffect, + PredOpTrait<"operand #0 and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"dest operand and result have same type", + TCresIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets, + I64ArrayAttr:$strides)>, + Results<(outs AnyVector)> { + let summary = "strided_slice operation"; + let description = [{ + Takes a k-D source vector, an n-D destination vector (n >= k), n-D `offsets` + integer array 
attribute, a k-D `strides` integer array attribute and inserts + the k-D source vector as a strided subvector at the proper offset into the + n-D destination vector. + + At the moment strides must contain only 1s. + + Returns an n-D vector that is a copy of the n-D destination vector in which + the last k-D dimensions contain the k-D source vector elements strided at + the proper location as specified by the offsets. + + Examples: + ``` + %2 = vector.insert_strided_slice %0, %1 + {offsets : [0, 0, 2], strides : [1, 1]}: + vector<2x4xf32> into vector<16x4x8xf32> + ``` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value source, Value dest, " # + "ArrayRef offsets, ArrayRef strides">]; + let extraClassDeclaration = [{ + static StringRef getOffsetsAttrName() { return "offsets"; } + static StringRef getStridesAttrName() { return "strides"; } + VectorType getSourceVectorType() { + return source()->getType().cast(); + } + VectorType getDestVectorType() { + return dest()->getType().cast(); + } + }]; +} + +def Vector_OuterProductOp : + Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>, + Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic:$acc)>, + Results<(outs AnyVector)> { + let summary = "vector outerproduct with optional fused add"; + let description = [{ + Takes 2 1-D vectors and returns the 2-D vector containing the outer product. + + An optional extra 2-D vector argument may be specified in which case the + operation returns the sum of the outer product and the extra vector. When + lowered to the LLVMIR dialect, this form emits `llvm.intr.fmuladd`, which + can lower to actual `fma` instructions in LLVM. 
+ + Examples + + %2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32> + return %2: vector<4x8xf32> + + %3 = vector.outerproduct %0, %1, %2: + vector<4xf32>, vector<8xf32>, vector<4x8xf32> + return %3: vector<4x8xf32> + }]; + let extraClassDeclaration = [{ + VectorType getOperandVectorTypeLHS() { + return lhs()->getType().cast(); + } + VectorType getOperandVectorTypeRHS() { + return rhs()->getType().cast(); + } + VectorType getOperandVectorTypeACC() { + return (llvm::size(acc()) == 0) ? VectorType() : + (*acc().begin())->getType().cast(); + } + VectorType getVectorType() { + return getResult()->getType().cast(); + } + }]; +} + +// TODO(andydavis) Add transformation which decomposes ReshapeOp into an +// optimized sequence of vector rotate/shuffle/select operations. +def Vector_ReshapeOp : + Vector_Op<"reshape", [AttrSizedOperandSegments, NoSideEffect]>, + Arguments<(ins AnyVector:$vector, Variadic:$input_shape, + Variadic:$output_shape, + I64ArrayAttr:$fixed_vector_sizes, + I32ElementsAttr:$operand_segment_sizes)>, + Results<(outs AnyVector)> { + let summary = "vector reshape operation"; + let description = [{ + Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining + fixed vector dimension 'fixed_vector_sizes' on the innermost vector + dimensions. + + The parameters 'input_shape' and 'output_shape' represent valid data shapes + across fixed vector shapes. For example, if a vector has a valid data + shape [6] with fixed vector size [8], then the valid data elements are + assumed to be stored at the beginning of the vector with the remaining + vector elements undefined. + + In the examples below, valid data elements are represented by an alphabetic + character, and undefined data elements are represented by '-'. 
+ + Example + + vector<1x8xf32> with valid data shape [6], fixed vector sizes [8] + + input: [a, b, c, d, e, f] + + layout map: (d0) -> (d0 floordiv 8, d0 mod 8) + + vector layout: [a, b, c, d, e, f, -, -] + + Example + + vector<2x8xf32> with valid data shape [10], fixed vector sizes [8] + + input: [a, b, c, d, e, f, g, h, i, j] + + layout map: (d0) -> (d0 floordiv 8, d0 mod 8) + + vector layout: [[a, b, c, d, e, f, g, h], + [i, j, -, -, -, -, -, -]] + + Example + + vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes + [2, 3] + + input: [[a, b, c, d, e], + [f, g, h, i, j], + [k, l, m, n, o]] + + layout map: (d0, d1) -> (d0 floordiv 3, d1 floordiv 5, + d0 mod 3, d1 mod 5) + + vector layout: [[[[a, b, c], + [f, g, h]] + [[d, e, -], + [i, j, -]]], + [[[k, l, m], + [-, -, -]] + [[n, o, -], + [-, -, -]]]] + + Example + + %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4] + : vector<3x2x4xf32> to vector<2x3x4xf32> + + input: [[a, b, c, d, e, f], + [g, h, i, j, k, l], + [m, n, o, p, q, r]] + + layout map: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4) + + + Input vector: [[[a, b, c, d], + [e, f, -, -]], + [[g, h, i, j], + [k, l, -, -]], + [[m, n, o, p], + [q, r, -, -]]] + + Output vector: [[[a, b, c, d], + [e, f, g, h], + [i, -, -, -]], + [[j, k, l, m], + [n, o, p, q], + [r, -, -, -]]] + }]; + + let extraClassDeclaration = [{ + VectorType getInputVectorType() { + return vector()->getType().cast(); + } + VectorType getOutputVectorType() { + return getResult()->getType().cast(); + } + + /// Returns as integer value the number of input shape operands. + int64_t getNumInputShapeSizes() { return input_shape().size(); } + + /// Returns as integer value the number of output shape operands. 
+ int64_t getNumOutputShapeSizes() { return output_shape().size(); } + + void getFixedVectorSizes(SmallVectorImpl &results); + + static StringRef getFixedVectorSizesAttrName() { + return "fixed_vector_sizes"; + } + static StringRef getInputShapeAttrName() { return "input_shape"; } + static StringRef getOutputShapeAttrName() { return "output_shape"; } + }]; +} + +def Vector_StridedSliceOp : + Vector_Op<"strided_slice", [NoSideEffect, + PredOpTrait<"operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets, + I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>, + Results<(outs AnyVector)> { + let summary = "strided_slice operation"; + let description = [{ + Takes an n-D vector, k-D `offsets` integer array attribute, a k-D `sizes` + integer array attribute, a k-D `strides` integer array attribute and + extracts the n-D subvector at the proper offset. + + At the moment strides must contain only 1s. + // TODO(ntv) support non-1 strides. + + Returns an n-D vector where the first k-D dimensions match the `sizes` + attribute. The returned subvector contains the elements starting at offset + `offsets` and ending at `offsets + sizes`. 
+ + Examples: + ``` + %1 = vector.strided_slice %0 + {offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}: + vector<4x8x16xf32> to vector<2x4x16xf32> + ``` + + // TODO(ntv) Evolve to a range form syntax similar to: + %1 = vector.strided_slice %0[0:2:1][2:4:1] + vector<4x8x16xf32> to vector<2x4x16xf32> + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value source, " # + "ArrayRef offsets, ArrayRef sizes, " # + "ArrayRef strides">]; + let extraClassDeclaration = [{ + static StringRef getOffsetsAttrName() { return "offsets"; } + static StringRef getSizesAttrName() { return "sizes"; } + static StringRef getStridesAttrName() { return "strides"; } + VectorType getVectorType(){ return vector()->getType().cast(); } + void getOffsets(SmallVectorImpl &results); + }]; + let hasCanonicalizer = 1; +} + +def Vector_TransferReadOp : + Vector_Op<"transfer_read">, + Arguments<(ins AnyMemRef:$memref, Variadic:$indices, + AffineMapAttr:$permutation_map, AnyType:$padding)>, + Results<(outs AnyVector:$vector)> { + + let summary = "Reads a supervector from memory into an SSA vector value."; + + let description = [{ + The `vector.transfer_read` op performs a blocking read from a slice within + a [MemRef](../LangRef.md#memref-type) supplied as its first operand + into a [vector](../LangRef.md#vector-type) of the same base elemental type. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map + [attribute](../LangRef.md#attributes) is an + [affine-map](Affine.md#affine-maps) which specifies the transposition on the + slice to match the vector shape. The size of the slice is specified by the + size of the vector, given as the return type. 
An `ssa-value` of the same + elemental type as the MemRef is provided as the last operand to specify + padding in the case of out-of-bounds accesses. This operation is called + 'read' by opposition to 'load' because the super-vector granularity is + generally not representable with a single hardware register. + A `vector.transfer_read` is thus a mid-level + abstraction that supports super-vectorization with non-effecting padding for + full-tile-only code. + + More precisely, let's dive deeper into the permutation_map for the following + MLIR: + + ```mlir + vector.transfer_read %A[%expr1, %expr2, %expr3, %expr4] + { permutation_map : (d0,d1,d2,d3) -> (d2,0,d0) } : + memref, vector<3x4x5xf32> + ``` + + This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3, + %expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice + is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]` + + That slice needs to be read into a `vector<3x4x5xf32>`. Since the + permutation map is not full rank, there must be a broadcast along vector + dimension `1`. + + A notional lowering of vector.transfer_read could generate code resembling: + + ```mlir + // %expr1, %expr2, %expr3, %expr4 defined before this point + %tmp = alloc() : vector<3x4x5xf32> + %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> + for %i = 0 to 3 { + affine.for %j = 0 to 4 { + affine.for %k = 0 to 5 { + %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : + memref + store %tmp[%i, %j, %k] : vector<3x4x5xf32> + }}} + %c0 = constant 0 : index + %vec = load %view_in_tmp[%c0] : vector<3x4x5xf32> + ``` + + On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that + the temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are + actually transferred between `%A` and `%tmp`. 
+ + Alternatively, if a notional vector broadcast operation were available, the + lowered code would resemble: + + ```mlir + // %expr1, %expr2, %expr3, %expr4 defined before this point + %tmp = alloc() : vector<3x4x5xf32> + %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> + for %i = 0 to 3 { + affine.for %k = 0 to 5 { + %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : + memref + store %tmp[%i, 0, %k] : vector<3x4x5xf32> + }} + %c0 = constant 0 : index + %tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32> + %vec = broadcast %tmpvec, 1 : vector<3x4x5xf32> + ``` + + where `broadcast` broadcasts from element 0 to all others along the + specified dimension. This time, the temporary storage footprint is `3 * 5` + values which is the same amount of data as the `3 * 5` values transferred. + An additional `1` broadcast is required. On a GPU this broadcast could be + implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`. + + Syntax + ``` + operation ::= ssa-id `=` `vector.transfer_read` ssa-use-list + `{` attribute-entry `} :` memref-type `,` vector-type + ``` + + Examples: + + ```mlir + // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> + // and pad with %f0 to handle the boundary case: + %f0 = constant 0.0f : f32 + for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 step 256 { + affine.for %i2 = 0 to %2 step 32 { + %v = vector.transfer_read %A[%i0, %i1, %i2], (%f0) + {permutation_map: (d0, d1, d2) -> (d2, d1)} : + memref, vector<32x256xf32> + }}} + + // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into + // vector<128xf32>. The underlying implementation will require a 1-D vector + // broadcast: + for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { + %3 = vector.transfer_read %A[%i0, %i1] + {permutation_map: (d0, d1) -> (0)} : + memref, vector<128xf32> + } + } + + // Read from a memref with vector element type. 
+ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 + {permutation_map = (d0, d1)->(d0, d1)} + : memref>, vector<1x1x4x3xf32> + ``` + }]; + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return memref()->getType().cast(); + } + VectorType getVectorType() { + return vector()->getType().cast(); + } + }]; +} + +def Vector_TransferWriteOp : + Vector_Op<"transfer_write">, + Arguments<(ins AnyVector:$vector, AnyMemRef:$memref, + Variadic:$indices, + AffineMapAttr:$permutation_map)> { + + let summary = "The vector.transfer_write op writes a supervector to memory."; + + let description = [{ + The `vector.transfer_write` performs a blocking write from a + [vector](../LangRef.md#vector-type), supplied as its first operand, into a + slice within a [MemRef](../LangRef.md#memref-type) of the same base + elemental type, supplied as its second operand. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `3 .. 2 + rank(memref)`. + The permutation_map [attribute](../LangRef.md#attributes) is an + [affine-map](Affine.md#affine-maps) which specifies the transposition on the + slice to match the vector shape. The size of the slice is specified by the + size of the vector. This operation is called 'write' by opposition to + 'store' because the super-vector granularity is generally not representable + with a single hardware register. A `vector.transfer_write` is thus a + mid-level abstraction that supports super-vectorization with non-effecting + padding for full-tile-only code. It is the responsibility of + `vector.transfer_write`'s implementation to ensure the memory writes are + valid. Different lowerings may be pertinent depending on the hardware + support. 
+ + Syntax: + + ``` + operation ::= `vector.transfer_write` ssa-use-list `{` attribute-entry `} : + ` vector-type ', ' memref-type ' + ``` + + Examples: + + ```mlir + // write vector<16x32x64xf32> into the slice + // `%A[%i0, %i1:%i1+32, %i2:%i2+64, %i3:%i3+16]`: + for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 step 32 { + affine.for %i2 = 0 to %2 step 64 { + affine.for %i3 = 0 to %3 step 16 { + %val = `ssa-value` : vector<16x32x64xf32> + vector.transfer_write %val, %A[%i0, %i1, %i2, %i3] + {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : + vector<16x32x64xf32>, memref + }}}} + + // write to a memref with vector element type. + vector.transfer_write %4, %arg1[%c3, %c3] + {permutation_map = (d0, d1)->(d0, d1)} + : vector<1x1x4x3xf32>, memref> + ``` + }]; + + let extraClassDeclaration = [{ + VectorType getVectorType() { + return vector()->getType().cast(); + } + MemRefType getMemRefType() { + return memref()->getType().cast(); + } + }]; +} + +def Vector_TypeCastOp : + Vector_Op<"type_cast", [NoSideEffect]>, + Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>, + Results<(outs AnyMemRef)> { + let summary = "type_cast op converts a scalar memref to a vector memref"; + let description = [{ + Performs a conversion from a memref with scalar element to a memref with a + *single* vector element, copying the shape of the memref to the vector. This + is the minimal viable operation that is required to make + super-vectorization operational. It can be seen as a special case of the + `view` operation but scoped in the super-vectorization context. 
+ + Syntax: + + ``` + operation ::= `vector.type_cast` ssa-use : memref-type to memref-type + ``` + + Example: + + ```mlir + %A = alloc() : memref<5x4x3xf32> + %VA = vector.type_cast %A : memref<5x4x3xf32> to memref> + ``` + }]; + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value source">]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return memref()->getType().cast(); + } + MemRefType getResultMemRefType() { + return getResult()->getType().cast(); + } + }]; +} + +def Vector_ConstantMaskOp : + Vector_Op<"constant_mask", [NoSideEffect]>, + Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>, + Results<(outs VectorOf<[I1]>)> { + let summary = "creates a constant vector mask"; + let description = [{ + Creates and returns a vector mask where elements of the result vector + are set to '0' or '1', based on whether the element indices are contained + within a hyper-rectangular region specified by the 'mask_dim_sizes' + array attribute argument. Each element of the 'mask_dim_sizes' array, + specifies an exclusive upper bound [0, mask-dim-size-element-value) + for a unique dimension in the vector result. The conjunction of the ranges + define a hyper-rectangular region within which elements values are set to 1 + (otherwise element values are set to 0). + + Example: create a constant vector mask of size 4x3xi1 with elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). 
+ + %1 = vector.constant_mask [3, 2] : vector<4x3xi1> + + print %1 + columns + 0 1 2 + |------------ + 0 | 1 1 0 + rows 1 | 1 1 0 + 2 | 1 1 0 + 3 | 0 0 0 + }]; + + let extraClassDeclaration = [{ + static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; } + }]; +} + +def Vector_CreateMaskOp : + Vector_Op<"create_mask", [NoSideEffect]>, + Arguments<(ins Variadic:$operands)>, Results<(outs VectorOf<[I1]>)> { + let summary = "creates a vector mask"; + let description = [{ + Creates and returns a vector mask where elements of the result vector + are set to '0' or '1', based on whether the element indices are contained + within a hyper-rectangular region specified by the operands. Specifically, + each operand specifies a range [0, operand-value) for a unique dimension in + the vector result. The conjunction of the operand ranges define a + hyper-rectangular region within which elements values are set to 1 + (otherwise element values are set to 0). + + Example: create a vector mask of size 4x3xi1 where elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + + %1 = vector.create_mask %c3, %c2 : vector<4x3xi1> + + print %1 + columns + 0 1 2 + |------------ + 0 | 1 1 0 + rows 1 | 1 1 0 + 2 | 1 1 0 + 3 | 0 0 0 + }]; + + let hasCanonicalizer = 1; +} + +def Vector_TupleOp : + Vector_Op<"tuple", [NoSideEffect]>, + Arguments<(ins Variadic:$vectors)>, + Results<(outs TupleOf<[AnyVector]>)> { + let summary = "make tuple of vectors operation"; + let description = [{ + Returns a tuple of its operands 'vectors'. + + Note that this operation is used during the vector op unrolling + transformation and should be removed before lowering to lower-level + dialects. + + + Examples: + ``` + %0 = vector.transfer_read ... : vector<2x2xf32> + %1 = vector.transfer_read ... : vector<2x1xf32> + %2 = vector.transfer_read ... : vector<2x2xf32> + %3 = vector.transfer_read ... 
: vector<2x1xf32> + + %4 = vector.tuple %0, %1, %2, %3 + : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32> + + ``` + }]; + + let extraClassDeclaration = [{ + TupleType getResultTupleType() { + return getResult()->getType().cast(); + } + }]; +} + +def Vector_TupleGetOp : + Vector_Op<"tuple_get", [NoSideEffect]>, + Arguments<(ins TupleOf<[AnyVector]>:$vectors, APIntAttr:$index)>, + Results<(outs AnyVector)> { + let summary = "vector tuple get operation"; + let description = [{ + Returns the tuple element of 'vectors' at 'index'. + + Note that this operation is used during the vector op unrolling + transformation and should be removed before lowering to lower-level + dialects. + + Examples: + ``` + %4 = vector.tuple %0, %1, %2, %3 + : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>> + + %5 = vector.tuple_get %4, 1 + : tuple, vector<2x1xf32>, + vector<2x2xf32>, vector<2x1xf32>> + ``` + }]; + + let extraClassDeclaration = [{ + VectorType getResultVectorType() { + return getResult()->getType().cast(); + } + int64_t getIndex() { + return getAttrOfType("index").getValue().getSExtValue(); + } + static StringRef getIndexAttrName() { return "index"; } + }]; +} + +def Vector_PrintOp : + Vector_Op<"print", []>, Arguments<(ins AnyType:$source)> { + let summary = "print operation (for testing and debugging)"; + let description = [{ + Prints the source vector (or scalar) to stdout in human readable + format (for testing and debugging). No return value. + + Examples: + ``` + %0 = constant 0.0 : f32 + %1 = vector.broadcast %0 : f32 to vector<4xf32> + vector.print %1 : vector<4xf32> + + when lowered to LLVM, the vector print is unrolled into + elementary printing method calls that at runtime will yield + + ( 0.0, 0.0, 0.0, 0.0 ) + + on stdout when linked with a small runtime support library, + which only needs to provide a few printing methods (single + value for all data types, opening/closing bracket, comma, + newline). 
+ ``` + }]; + let verifier = ?; + let extraClassDeclaration = [{ + Type getPrintType() { + return source()->getType(); + } + }]; +} + +#endif // VECTOR_OPS diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td new file mode 100644 index 0000000000000000000000000000000000000000..5d0244f6989537c20e9d0561457d078d7f383e89 --- /dev/null +++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td @@ -0,0 +1,26 @@ +//===- VectorTransformPatterns.td - Vector-Vector patterns -*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the pattern definition file for declarative Vector transformations. +// +//===----------------------------------------------------------------------===// + +#ifndef VECTOR_TRANSFORM_PATTERNS +#define VECTOR_TRANSFORM_PATTERNS + +include "mlir/IR/OpBase.td" + +class HasShape shape> : + CPred<"$0->getType().cast().hasStaticShape({" # + StrJoinInt.result # "})">; + +class UnrollVectorOp factors> : NativeCodeCall< + "unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " # + "{" # StrJoinInt.result # "})">; + +#endif // VECTOR_TRANSFORM_PATTERNS diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h new file mode 100644 index 0000000000000000000000000000000000000000..feb8bd60445ba921815a9dc374a3f7e2c25246c6 --- /dev/null +++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h @@ -0,0 +1,73 @@ +//===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ +#define DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +class MLIRContext; +class OwningRewritePatternList; + +/// Collect a set of patterns to convert from the Vector dialect to itself. +/// Should be merged with populateVectorToAffineLoopsConversionPatterns. +void populateVectorToVectorConversionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + ArrayRef coarseVectorShape = {}, + ArrayRef fineVectorShape = {}); + +//////////////////////////////////////////////////////////////////////////////// +// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite +// patterns. As such, they must not call into `rewriter.erase/replace` APIs and +// it is the responsibility of the enclosing PatternRewriter to erase on +// success. +//////////////////////////////////////////////////////////////////////////////// + +namespace vector { + +// Entry point for unrolling declarative pattern rewrites. +// `op` is unrolled to the `targetShape` as follows, for each of its operands: +// 1. the unrolled type `unrolledVectorType` and number of unrolled instances +// `numUnrolledInstances` are computed from the `targetShape`. For now it is +// assumed the unrolling factors divide the vector sizes. +// 2. a fakeFork cast op is inserted that takes the operand and returns +// `numUnrolledInstances` results of type `unrolledVectorType`. +// 3. the original op is cloned `numUnrolledInstances` times, once for each +// result of the fakeFork cast op. +// 4. a fakeJoin cast op takes all these results and merges them into a single +// aggregate vector result whose size matches the original non-unrolled op +// operand types. 
+// +// Example: +// +// opA(operand0, operand1) // numUnrolledInstances = 3 +// +// operand0 operand1 +// | | +// fork fork +// <----------gather all fork ops ---------> +// /|\ /|\ +// f00 f01 f02 f10 f11 f12 +// <---------- clone op 3 times ---------> +// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) +// \ | / +// <-------------------- join -------------------------> +// +// Other local patterns then kick in iteratively (including DCE) and compose +// until all the fakeFork and fakeJoin ops are removed. +// +// This will be extended in the future to support more advanced use cases than +// simple pointwise ops. +Value unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, + ArrayRef targetShape); + +} // namespace vector +} // namespace mlir + +#endif // DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h new file mode 100644 index 0000000000000000000000000000000000000000..d598c1cfb23ba2e8cc204182f3a8b7654a5e780e --- /dev/null +++ b/mlir/include/mlir/EDSC/Builders.h @@ -0,0 +1,538 @@ +//===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides intuitive composable interfaces for building structured MLIR +// snippets in a declarative fashion. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EDSC_BUILDERS_H_ +#define MLIR_EDSC_BUILDERS_H_ + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/Transforms/FoldUtils.h" + +namespace mlir { + +namespace edsc { + +struct index_t { + explicit index_t(int64_t v) : v(v) {} + explicit operator int64_t() { return v; } + int64_t v; +}; + +class BlockHandle; +class CapturableHandle; +class NestedBuilder; +class ValueHandle; + +/// Helper class to transparently handle builder insertion points by RAII. +/// As its name indicates, a ScopedContext is means to be used locally in a +/// scoped fashion. This abstracts away all the boilerplate related to +/// checking proper usage of captures, NestedBuilders as well as handling the +/// setting and restoring of insertion points. +class ScopedContext { +public: + ScopedContext(OpBuilder &builder, Location location); + + /// Sets the insertion point of the builder to 'newInsertPt' for the duration + /// of the scope. The existing insertion point of the builder is restored on + /// destruction. + ScopedContext(OpBuilder &builder, OpBuilder::InsertPoint newInsertPt, + Location location); + ~ScopedContext(); + + static MLIRContext *getContext(); + static OpBuilder &getBuilder(); + static Location getLocation(); + +private: + /// Only NestedBuilder (which is used to create an operation with a body) + /// may access private members in order to implement scoping. + friend class NestedBuilder; + + ScopedContext() = delete; + ScopedContext(const ScopedContext &) = delete; + ScopedContext &operator=(const ScopedContext &) = delete; + + static ScopedContext *&getCurrentScopedContext(); + + /// Top level OpBuilder. + OpBuilder &builder; + /// The previous insertion point of the builder. + Optional prevBuilderInsertPoint; + /// Current location. 
+ Location location; + /// Parent context we return into. + ScopedContext *enclosingScopedContext; + /// Defensively keeps track of the current NestedBuilder to ensure proper + /// scoping usage. + NestedBuilder *nestedBuilder; + + // TODO: Implement scoping of ValueHandles. To do this we need a proper data + // structure to hold ValueHandle objects. We can emulate one but there should + // already be something available in LLVM for this purpose. +}; + +/// A NestedBuilder is a scoping abstraction to create an idiomatic syntax +/// embedded in C++ that serves the purpose of building nested MLIR. +/// Nesting and compositionality is obtained by using the strict ordering that +/// exists between object construction and method invocation on said object (in +/// our case, the call to `operator()`). +/// This ordering allows implementing an abstraction that decouples definition +/// from declaration (in a PL sense) on placeholders of type ValueHandle and +/// BlockHandle. +class NestedBuilder { +protected: + NestedBuilder() = default; + NestedBuilder(const NestedBuilder &) = delete; + NestedBuilder(NestedBuilder &&other) : bodyScope(other.bodyScope) { + other.bodyScope = nullptr; + } + + NestedBuilder &operator=(const NestedBuilder &) = delete; + NestedBuilder &operator=(NestedBuilder &&other) { + std::swap(bodyScope, other.bodyScope); + return *this; + } + + /// Enter an mlir::Block and setup a ScopedContext to insert operations at + /// the end of it. Since we cannot use c++ language-level scoping to implement + /// scoping itself, we use enter/exit pairs of operations. + /// As a consequence we must allocate a new OpBuilder + ScopedContext and + /// let them escape. + /// Step back "prev" times from the end of the block to set up the insertion + /// point, which is useful for non-empty blocks. 
+ void enter(mlir::Block *block, int prev = 0) { + bodyScope = new ScopedContext( + ScopedContext::getBuilder(), + OpBuilder::InsertPoint(block, std::prev(block->end(), prev)), + ScopedContext::getLocation()); + bodyScope->nestedBuilder = this; + } + + /// Exit the current mlir::Block by explicitly deleting the dynamically + /// allocated OpBuilder and ScopedContext. + void exit() { + // Reclaim now to exit the scope. + bodyScope->nestedBuilder = nullptr; + delete bodyScope; + bodyScope = nullptr; + } + + /// Custom destructor does nothing because we already destroyed bodyScope + /// manually in `exit`. Insert an assertion to defensively guard against + /// improper usage of scoping. + ~NestedBuilder() { + assert(!bodyScope && + "Illegal use of NestedBuilder; must have called exit()"); + } + +private: + ScopedContext *bodyScope = nullptr; +}; + +/// A LoopBuilder is a generic NestedBuilder for loop-like MLIR operations. +/// More specifically it is meant to be used as a temporary object for +/// representing any nested MLIR construct that is "related to" an mlir::Value +/// (for now an induction variable). +/// This is extensible and will evolve in the future as MLIR evolves, hence +/// the name LoopBuilder (as opposed to say ForBuilder or AffineForBuilder). +class LoopBuilder : public NestedBuilder { +public: + /// Constructs a new AffineForOp and captures the associated induction + /// variable. A ValueHandle pointer is passed as the first argument and is the + /// *only* way to capture the loop induction variable. + static LoopBuilder makeAffine(ValueHandle *iv, + ArrayRef lbHandles, + ArrayRef ubHandles, int64_t step); + /// Constructs a new loop::ForOp and captures the associated induction + /// variable. A ValueHandle pointer is passed as the first argument and is the + /// *only* way to capture the loop induction variable. 
+ static LoopBuilder makeLoop(ValueHandle *iv, ValueHandle lbHandle, + ValueHandle ubHandle, ValueHandle stepHandle); + LoopBuilder(const LoopBuilder &) = delete; + LoopBuilder(LoopBuilder &&) = default; + + LoopBuilder &operator=(const LoopBuilder &) = delete; + LoopBuilder &operator=(LoopBuilder &&) = default; + + /// The only purpose of this operator is to serve as a sequence point so that + /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is + /// scoped within a LoopBuilder. + void operator()(function_ref fun = nullptr); + +private: + LoopBuilder() = default; +}; + +/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid +/// explicitly writing all the loops in a nest. This simple functionality is +/// also useful to write rank-agnostic custom ops. +/// +/// Usage: +/// +/// ```c++ +/// AffineLoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, +/// 1})( +/// [&](){ +/// ... +/// }); +/// ``` +/// +/// ```c++ +/// AffineLoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){ +/// AffineLoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){ +/// AffineLoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){ +/// ... +/// }), +/// }), +/// }); +/// ``` +class AffineLoopNestBuilder { +public: + // This entry point accommodates the fact that AffineForOp implicitly uses + // multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max` + // and `min` constraints respectively. + AffineLoopNestBuilder(ValueHandle *iv, ArrayRef lbs, + ArrayRef ubs, int64_t step); + AffineLoopNestBuilder(ArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); + + void operator()(function_ref fun = nullptr); + +private: + SmallVector loops; +}; + +/// Helper class to sugar building loop.for loop nests from ranges. +/// This is similar to edsc::AffineLoopNestBuilder except it operates on +/// loop.for. 
+class LoopNestBuilder { +public: + LoopNestBuilder(ArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); + void operator()(std::function fun = nullptr); + +private: + SmallVector loops; +}; + +// This class exists solely to handle the C++ vexing parse case when +// trying to enter a Block that has already been constructed. +class Append {}; + +/// A BlockBuilder is a NestedBuilder for mlir::Block*. +/// This exists by opposition to LoopBuilder which is not related to an +/// mlir::Block* but to a mlir::Value. +/// It is meant to be used as a temporary object for representing any nested +/// MLIR construct that is "related to" an mlir::Block*. +class BlockBuilder : public NestedBuilder { +public: + /// Enters the mlir::Block* previously captured by `bh` and sets the insertion + /// point to its end. + BlockBuilder(BlockHandle bh, Append); + + /// Constructs a new mlir::Block with argument types derived from `args`. + /// Captures the new block in `bh` and its arguments into `args`. + /// Enters the new mlir::Block* and sets the insertion point to its end. + /// + /// Prerequisites: + /// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are + /// not yet bound to mlir::Value. + BlockBuilder(BlockHandle *bh, ArrayRef args); + + /// The only purpose of this operator is to serve as a sequence point so that + /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is + /// scoped within a BlockBuilder. + void operator()(function_ref fun = nullptr); + +private: + BlockBuilder(BlockBuilder &) = delete; + BlockBuilder &operator=(BlockBuilder &other) = delete; +}; + +/// Base class for ValueHandle, OperationHandle and BlockHandle. +/// Not meant to be used outside of these classes. +class CapturableHandle { +protected: + CapturableHandle() = default; +}; + +/// ValueHandle implements a (potentially "delayed") typed Value abstraction. +/// ValueHandle should be captured by pointer but otherwise passed by Value +/// everywhere. 
+/// A ValueHandle can have 3 states: +/// 1. null state (empty type and empty value), in which case it does not hold +/// a value and must never hold a Value (now or in the future). This is +/// used for MLIR operations with zero returns as well as the result of +/// calling a NestedBuilder::operator(). In both cases the objective is to +/// have an object that can be inserted in an ArrayRef to +/// implement nesting; +/// 2. delayed state (empty value), in which case it represents an eagerly +/// typed "delayed" value that can hold a Value in the future; +/// 3. constructed state, in which case it holds a Value. +/// +/// A ValueHandle is meant to capture a single Value and should be used for +/// operations that have a single result. For convenience of use, we also +/// include AffineForOp in this category although it does not return a value. +/// In the case of AffineForOp, the captured Value is the loop induction +/// variable. +class ValueHandle : public CapturableHandle { +public: + /// A ValueHandle in a null state can never be captured; + static ValueHandle null() { return ValueHandle(); } + + /// A ValueHandle that is constructed from a Type represents a typed "delayed" + /// Value. A delayed Value can only capture Values of the specified type. + /// Such a delayed value represents the declaration (in the PL sense) of a + /// placeholder for an mlir::Value that will be constructed and captured at + /// some later point in the program. + explicit ValueHandle(Type t) : t(t), v(nullptr) {} + + /// A ValueHandle that is constructed from an mlir::Value is an "eager" + /// Value. An eager Value represents both the declaration and the definition + /// (in the PL sense) of a placeholder for an mlir::Value that has already + /// been constructed in the past and that is captured "now" in the program. + explicit ValueHandle(Value v) : t(v->getType()), v(v) {} + + /// Builds a ConstantIndexOp of value `cst`. 
The constant is created at the + /// current insertion point. + /// This implicit constructor is provided to easily build an eager Value for a + /// constant at the current insertion point in the IR. An implicit constructor + /// allows idiomatic expressions mixing ValueHandle and literals. + ValueHandle(index_t cst); + + /// ValueHandle is a value type, use the default copy constructor. + ValueHandle(const ValueHandle &other) = default; + + /// ValueHandle is a value type, the assignment operator typechecks before + /// assigning. + ValueHandle &operator=(const ValueHandle &other); + + /// Provide a swap operator. + void swap(ValueHandle &other) { + if (this == &other) + return; + std::swap(t, other.t); + std::swap(v, other.v); + } + + /// Implicit conversion useful for automatic conversion to Container. + operator Value() const { return getValue(); } + operator bool() const { return hasValue(); } + + /// Generic mlir::Op create. This is the key to being extensible to the whole + /// of MLIR without duplicating the type system or the op definitions. + template + static ValueHandle create(Args... args); + + /// Generic mlir::Op create. This is the key to being extensible to the whole + /// of MLIR without duplicating the type system or the op definitions. + /// When non-null, the optional pointer `folder` is used to call into the + /// `createAndFold` builder method. If `folder` is null, the regular `create` + /// method is called. + template + static ValueHandle create(OperationFolder *folder, Args... args); + + /// Special case to build composed AffineApply operations. + // TODO: createOrFold when available and move inside of the `create` method. + static ValueHandle createComposedAffineApply(AffineMap map, + ArrayRef operands); + + /// Generic create for a named operation producing a single value. 
+ static ValueHandle create(StringRef name, ArrayRef operands, + ArrayRef resultTypes, + ArrayRef attributes = {}); + + bool hasValue() const { return v != nullptr; } + Value getValue() const { + assert(hasValue() && "Unexpected null value;"); + return v; + } + bool hasType() const { return t != Type(); } + Type getType() const { return t; } + + Operation *getOperation() const { + if (!v) + return nullptr; + return v->getDefiningOp(); + } + +protected: + ValueHandle() : t(), v(nullptr) {} + + Type t; + Value v; +}; + +/// An OperationHandle can be used in lieu of ValueHandle to capture the +/// operation in cases when one does not care about, or cannot extract, a +/// unique Value from the operation. +/// This can be used for capturing zero result operations as well as +/// multi-result operations that are not supported by ValueHandle. +/// We do not distinguish further between zero and multi-result operations at +/// this time. +struct OperationHandle : public CapturableHandle { + OperationHandle() : op(nullptr) {} + OperationHandle(Operation *op) : op(op) {} + + OperationHandle(const OperationHandle &) = default; + OperationHandle &operator=(const OperationHandle &) = default; + + /// Generic mlir::Op create. This is the key to being extensible to the whole + /// of MLIR without duplicating the type system or the op definitions. + template + static OperationHandle create(Args... args); + template static Op createOp(Args... args); + + /// Generic create for a named operation. + static OperationHandle create(StringRef name, ArrayRef operands, + ArrayRef resultTypes, + ArrayRef attributes = {}); + + operator Operation *() { return op; } + Operation *getOperation() const { return op; } + +private: + Operation *op; +}; + +/// Simple wrapper to build a generic operation without successor blocks. 
+template struct CustomOperation { + CustomOperation(StringRef name) : name(name) { + static_assert(std::is_same() || + std::is_same(), + "Only CustomOperation or " + "CustomOperation can be constructed."); + } + HandleType operator()(ArrayRef operands = {}, + ArrayRef resultTypes = {}, + ArrayRef attributes = {}) { + return HandleType::create(name, operands, resultTypes, attributes); + } + std::string name; +}; + +/// A BlockHandle represents a (potentially "delayed") Block abstraction. +/// This extra abstraction is necessary because an mlir::Block is not an +/// mlir::Value. +/// A BlockHandle should be captured by pointer but otherwise passed by Value +/// everywhere. +class BlockHandle : public CapturableHandle { +public: + /// A BlockHandle constructed without an mlir::Block* represents a "delayed" + /// Block. A delayed Block represents the declaration (in the PL sense) of a + /// placeholder for an mlir::Block* that will be constructed and captured at + /// some later point in the program. + BlockHandle() : block(nullptr) {} + + /// A BlockHandle constructed with an mlir::Block* represents an "eager" + /// Block. An eager Block represents both the declaration and the definition + /// (in the PL sense) of a placeholder for an mlir::Block* that has already + /// been constructed in the past and that is captured "now" in the program. + BlockHandle(mlir::Block *block) : block(block) {} + + /// BlockHandle is a value type, use the default copy constructor and + /// assignment operator. + BlockHandle(const BlockHandle &) = default; + BlockHandle &operator=(const BlockHandle &) = default; + + /// Delegates block creation to MLIR and wrap the resulting mlir::Block. + static BlockHandle create(ArrayRef argTypes); + + operator bool() { return block != nullptr; } + operator mlir::Block *() { return block; } + mlir::Block *getBlock() { return block; } + +private: + mlir::Block *block; +}; + +template +OperationHandle OperationHandle::create(Args... 
args) { + return OperationHandle(ScopedContext::getBuilder() + .create(ScopedContext::getLocation(), args...) + .getOperation()); +} + +template +Op OperationHandle::createOp(Args... args) { + return cast( + OperationHandle(ScopedContext::getBuilder() + .create(ScopedContext::getLocation(), args...) + .getOperation()) + .getOperation()); +} + +template +ValueHandle ValueHandle::create(Args... args) { + Operation *op = ScopedContext::getBuilder() + .create(ScopedContext::getLocation(), args...) + .getOperation(); + if (op->getNumResults() == 1) { + return ValueHandle(op->getResult(0)); + } else if (op->getNumResults() == 0) { + if (auto f = dyn_cast(op)) { + return ValueHandle(f.getInductionVar()); + } + } + llvm_unreachable("unsupported operation, use an OperationHandle instead"); +} + +template +ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) { + return folder ? ValueHandle(folder->create(ScopedContext::getBuilder(), + ScopedContext::getLocation(), + args...)) + : ValueHandle(ScopedContext::getBuilder().create( + ScopedContext::getLocation(), args...)); +} + +namespace op { + +ValueHandle operator+(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator-(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator*(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator/(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator%(ValueHandle lhs, ValueHandle rhs); +ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs); +ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs); + +ValueHandle operator!(ValueHandle value); +ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator||(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator^(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator==(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator<(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs); +ValueHandle 
operator>(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs); + +} // namespace op + +/// Entry point to build multiple ValueHandle from a `Container` of Value or +/// Type. +template +inline SmallVector makeValueHandles(Container values) { + SmallVector res; + res.reserve(values.size()); + for (auto v : values) + res.push_back(ValueHandle(v)); + return res; +} + +} // namespace edsc +} // namespace mlir + +#endif // MLIR_EDSC_BUILDERS_H_ diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..a7c0365225a750ead4f3105cad017e9d0e910104 --- /dev/null +++ b/mlir/include/mlir/EDSC/Helpers.h @@ -0,0 +1,258 @@ +//===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides helper classes and syntactic sugar for declarative builders. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EDSC_HELPERS_H_ +#define MLIR_EDSC_HELPERS_H_ + +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace edsc { + +// A TemplatedIndexedValue brings an index notation over the template Load and +// Store parameters. +template class TemplatedIndexedValue; + +// By default, edsc::IndexedValue provides an index notation around the affine +// load and stores. edsc::StdIndexedValue provides the standard load/store +// counterpart. +using IndexedValue = + TemplatedIndexedValue; +using StdIndexedValue = + TemplatedIndexedValue; + +// Base class for MemRefView and VectorView. 
+class View { +public: + unsigned rank() const { return lbs.size(); } + ValueHandle lb(unsigned idx) { return lbs[idx]; } + ValueHandle ub(unsigned idx) { return ubs[idx]; } + int64_t step(unsigned idx) { return steps[idx]; } + std::tuple range(unsigned idx) { + return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); + } + void swapRanges(unsigned i, unsigned j) { + if (i == j) + return; + lbs[i].swap(lbs[j]); + ubs[i].swap(ubs[j]); + std::swap(steps[i], steps[j]); + } + + ArrayRef getLbs() { return lbs; } + ArrayRef getUbs() { return ubs; } + ArrayRef getSteps() { return steps; } + +protected: + SmallVector lbs; + SmallVector ubs; + SmallVector steps; +}; + +/// A MemRefView represents the information required to step through a +/// MemRef. It has placeholders for non-contiguous tensors that fit within the +/// Fortran subarray model. +/// At the moment it can only capture a MemRef with an identity layout map. +// TODO(ntv): Support MemRefs with layoutMaps. +class MemRefView : public View { +public: + explicit MemRefView(Value v); + MemRefView(const MemRefView &) = default; + MemRefView &operator=(const MemRefView &) = default; + + unsigned fastestVarying() const { return rank() - 1; } + +private: + friend IndexedValue; + ValueHandle base; +}; + +/// A VectorView represents the information required to step through a +/// Vector accessing each scalar element at a time. It is the counterpart of +/// a MemRefView but for vectors. This exists purely for boilerplate avoidance. +class VectorView : public View { +public: + explicit VectorView(Value v); + VectorView(const VectorView &) = default; + VectorView &operator=(const VectorView &) = default; + +private: + friend IndexedValue; + ValueHandle base; +}; + +/// A TemplatedIndexedValue brings an index notation over the template Load and +/// Store parameters. 
This helper class is an abstraction purely for sugaring +/// purposes and allows writing compact expressions such as: +/// +/// ```mlir +/// // `IndexedValue` provided by default in the mlir::edsc namespace. +/// using IndexedValue = +/// TemplatedIndexedValue; +/// IndexedValue A(...), B(...), C(...); +/// For(ivs, zeros, shapeA, ones, { +/// C(ivs) = A(ivs) + B(ivs) +/// }); +/// ``` +/// +/// Assigning to an IndexedValue emits an actual `Store` operation, while +/// converting an IndexedValue to a ValueHandle emits an actual `Load` +/// operation. +template class TemplatedIndexedValue { +public: + explicit TemplatedIndexedValue(Type t) : base(t) {} + explicit TemplatedIndexedValue(Value v) + : TemplatedIndexedValue(ValueHandle(v)) {} + explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} + + TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default; + + TemplatedIndexedValue operator()() { return *this; } + /// Returns a new `TemplatedIndexedValue`. + TemplatedIndexedValue operator()(ValueHandle index) { + TemplatedIndexedValue res(base); + res.indices.push_back(index); + return res; + } + template + TemplatedIndexedValue operator()(ValueHandle index, Args... indices) { + return TemplatedIndexedValue(base, index).append(indices...); + } + TemplatedIndexedValue operator()(ArrayRef indices) { + return TemplatedIndexedValue(base, indices); + } + TemplatedIndexedValue operator()(ArrayRef indices) { + return TemplatedIndexedValue( + base, ArrayRef(indices.begin(), indices.end())); + } + + /// Emits a `store`. + // NOLINTNEXTLINE: unconventional-assign-operator + OperationHandle operator=(const TemplatedIndexedValue &rhs) { + ValueHandle rrhs(rhs); + return Store(rrhs, getBase(), {indices.begin(), indices.end()}); + } + // NOLINTNEXTLINE: unconventional-assign-operator + OperationHandle operator=(ValueHandle rhs) { + return Store(rhs, getBase(), {indices.begin(), indices.end()}); + } + + /// Emits a `load` when converting to a ValueHandle. 
+ operator ValueHandle() const { + return Load(getBase(), {indices.begin(), indices.end()}); + } + + /// Emits a `load` when converting to a Value. + Value operator*(void) const { + return Load(getBase(), {indices.begin(), indices.end()}).getValue(); + } + + ValueHandle getBase() const { return base; } + + /// Operator overloadings. + ValueHandle operator+(ValueHandle e); + ValueHandle operator-(ValueHandle e); + ValueHandle operator*(ValueHandle e); + ValueHandle operator/(ValueHandle e); + OperationHandle operator+=(ValueHandle e); + OperationHandle operator-=(ValueHandle e); + OperationHandle operator*=(ValueHandle e); + OperationHandle operator/=(ValueHandle e); + ValueHandle operator+(TemplatedIndexedValue e) { + return *this + static_cast(e); + } + ValueHandle operator-(TemplatedIndexedValue e) { + return *this - static_cast(e); + } + ValueHandle operator*(TemplatedIndexedValue e) { + return *this * static_cast(e); + } + ValueHandle operator/(TemplatedIndexedValue e) { + return *this / static_cast(e); + } + OperationHandle operator+=(TemplatedIndexedValue e) { + return this->operator+=(static_cast(e)); + } + OperationHandle operator-=(TemplatedIndexedValue e) { + return this->operator-=(static_cast(e)); + } + OperationHandle operator*=(TemplatedIndexedValue e) { + return this->operator*=(static_cast(e)); + } + OperationHandle operator/=(TemplatedIndexedValue e) { + return this->operator/=(static_cast(e)); + } + +private: + TemplatedIndexedValue(ValueHandle base, ArrayRef indices) + : base(base), indices(indices.begin(), indices.end()) {} + + TemplatedIndexedValue &append() { return *this; } + + template + TemplatedIndexedValue &append(T index, Args... indices) { + this->indices.push_back(static_cast(index)); + append(indices...); + return *this; + } + ValueHandle base; + SmallVector indices; +}; + +/// Operator overloadings. 
+template +ValueHandle TemplatedIndexedValue::operator+(ValueHandle e) { + using op::operator+; + return static_cast(*this) + e; +} +template +ValueHandle TemplatedIndexedValue::operator-(ValueHandle e) { + using op::operator-; + return static_cast(*this) - e; +} +template +ValueHandle TemplatedIndexedValue::operator*(ValueHandle e) { + using op::operator*; + return static_cast(*this) * e; +} +template +ValueHandle TemplatedIndexedValue::operator/(ValueHandle e) { + using op::operator/; + return static_cast(*this) / e; +} + +template +OperationHandle TemplatedIndexedValue::operator+=(ValueHandle e) { + using op::operator+; + return Store(*this + e, getBase(), {indices.begin(), indices.end()}); +} +template +OperationHandle TemplatedIndexedValue::operator-=(ValueHandle e) { + using op::operator-; + return Store(*this - e, getBase(), {indices.begin(), indices.end()}); +} +template +OperationHandle TemplatedIndexedValue::operator*=(ValueHandle e) { + using op::operator*; + return Store(*this * e, getBase(), {indices.begin(), indices.end()}); +} +template +OperationHandle TemplatedIndexedValue::operator/=(ValueHandle e) { + using op::operator/; + return Store(*this / e, getBase(), {indices.begin(), indices.end()}); +} + +} // namespace edsc +} // namespace mlir + +#endif // MLIR_EDSC_HELPERS_H_ diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..30cce6bb8d6152d93e0b3d016b65a3c626412df7 --- /dev/null +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -0,0 +1,276 @@ +//===- Intrinsics.h - MLIR Operations for Declarative Builders ---*- C++-*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides intuitive composable intrinsics for building snippets of MLIR +// declaratively +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EDSC_INTRINSICS_H_ +#define MLIR_EDSC_INTRINSICS_H_ + +#include "mlir/EDSC/Builders.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { + +class MemRefType; +class Type; + +namespace edsc { + +/// An IndexHandle is a simple wrapper around a ValueHandle. +/// IndexHandles are ubiquitous enough to justify a new type to allow simple +/// declarations without boilerplate such as: +/// +/// ```c++ +/// IndexHandle i, j, k; +/// ``` +struct IndexHandle : public ValueHandle { + explicit IndexHandle() + : ValueHandle(ScopedContext::getBuilder().getIndexType()) {} + explicit IndexHandle(index_t v) : ValueHandle(v) {} + explicit IndexHandle(Value v) : ValueHandle(v) { + assert(v->getType() == ScopedContext::getBuilder().getIndexType() && + "Expected index type"); + } + explicit IndexHandle(ValueHandle v) : ValueHandle(v) { + assert(v.getType() == ScopedContext::getBuilder().getIndexType() && + "Expected index type"); + } + IndexHandle &operator=(const ValueHandle &v) { + assert(v.getType() == ScopedContext::getBuilder().getIndexType() && + "Expected index type"); + /// Creating a new IndexHandle(v) and then std::swap rightly complains the + /// binding has already occurred and that we should use another name. + this->t = v.getType(); + this->v = v.getValue(); + return *this; + } +}; + +inline SmallVector makeIndexHandles(unsigned rank) { + return SmallVector(rank); +} + +/// Entry point to build multiple ValueHandle* from a mutable list `ivs` of T. 
+template +inline SmallVector +makeHandlePointers(MutableArrayRef ivs) { + SmallVector pivs; + pivs.reserve(ivs.size()); + for (auto &iv : ivs) { + pivs.push_back(&iv); + } + return pivs; +} + +/// Returns a vector of the underlying Value from `ivs`. +inline SmallVector extractValues(ArrayRef ivs) { + SmallVector vals; + vals.reserve(ivs.size()); + for (auto &iv : ivs) { + vals.push_back(iv.getValue()); + } + return vals; +} + +/// Provides a set of first class intrinsics. +/// In the future, most of intrinsics related to Operation that don't contain +/// other operations should be Tablegen'd. +namespace intrinsics { +namespace detail { +/// Helper structure to be used with ValueBuilder / OperationBuilder. +/// It serves the purpose of removing boilerplate specialization for the sole +/// purpose of implicitly converting ArrayRef -> ArrayRef. +class ValueHandleArray { +public: + ValueHandleArray(ArrayRef vals) { + values.append(vals.begin(), vals.end()); + } + ValueHandleArray(ArrayRef vals) { + values.append(vals.begin(), vals.end()); + } + ValueHandleArray(ArrayRef vals) { + SmallVector tmp(vals.begin(), vals.end()); + values.append(tmp.begin(), tmp.end()); + } + operator ArrayRef() { return values; } + +private: + ValueHandleArray() = default; + SmallVector values; +}; + +template inline T unpack(T value) { return value; } + +inline detail::ValueHandleArray unpack(ArrayRef values) { + return detail::ValueHandleArray(values); +} + +} // namespace detail + +/// Helper variadic abstraction to allow extending to any MLIR op without +/// boilerplate or Tablegen. +/// Arguably a builder is not a ValueHandle but in practice it is only used as +/// an alias to a notional ValueHandle. +/// Implementing it as a subclass allows it to compose all the way to Value. +/// Without subclassing, implicit conversion to Value would fail when composing +/// in patterns such as: `select(a, b, select(c, d, e))`. 
+template struct ValueBuilder : public ValueHandle { + // Builder-based + template + ValueBuilder(Args... args) + : ValueHandle(ValueHandle::create(detail::unpack(args)...)) {} + ValueBuilder(ArrayRef vs) + : ValueBuilder(ValueBuilder::create(detail::unpack(vs))) {} + template + ValueBuilder(ArrayRef vs, Args... args) + : ValueHandle(ValueHandle::create(detail::unpack(vs), + detail::unpack(args)...)) {} + template + ValueBuilder(T t, ArrayRef vs, Args... args) + : ValueHandle(ValueHandle::create( + detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {} + template + ValueBuilder(T1 t1, T2 t2, ArrayRef vs, Args... args) + : ValueHandle(ValueHandle::create( + detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), + detail::unpack(args)...)) {} + + /// Folder-based + template + ValueBuilder(OperationFolder *folder, Args... args) + : ValueHandle(ValueHandle::create(folder, detail::unpack(args)...)) {} + ValueBuilder(OperationFolder *folder, ArrayRef vs) + : ValueBuilder(ValueBuilder::create(folder, detail::unpack(vs))) {} + template + ValueBuilder(OperationFolder *folder, ArrayRef vs, Args... args) + : ValueHandle(ValueHandle::create(folder, detail::unpack(vs), + detail::unpack(args)...)) {} + template + ValueBuilder(OperationFolder *folder, T t, ArrayRef vs, + Args... args) + : ValueHandle(ValueHandle::create(folder, detail::unpack(t), + detail::unpack(vs), + detail::unpack(args)...)) {} + template + ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef vs, + Args... args) + : ValueHandle(ValueHandle::create( + folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), + detail::unpack(args)...)) {} + + ValueBuilder() : ValueHandle(ValueHandle::create()) {} +}; + +template struct OperationBuilder : public OperationHandle { + template + OperationBuilder(Args... 
args) + : OperationHandle(OperationHandle::create(detail::unpack(args)...)) {} + OperationBuilder(ArrayRef vs) + : OperationHandle(OperationHandle::create(detail::unpack(vs))) {} + template + OperationBuilder(ArrayRef vs, Args... args) + : OperationHandle(OperationHandle::create(detail::unpack(vs), + detail::unpack(args)...)) {} + template + OperationBuilder(T t, ArrayRef vs, Args... args) + : OperationHandle(OperationHandle::create( + detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {} + template + OperationBuilder(T1 t1, T2 t2, ArrayRef vs, Args... args) + : OperationHandle(OperationHandle::create( + detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), + detail::unpack(args)...)) {} + OperationBuilder() : OperationHandle(OperationHandle::create()) {} +}; + +using addf = ValueBuilder; +using affine_apply = ValueBuilder; +using affine_if = OperationBuilder; +using affine_load = ValueBuilder; +using affine_store = OperationBuilder; +using alloc = ValueBuilder; +using call = OperationBuilder; +using constant_float = ValueBuilder; +using constant_index = ValueBuilder; +using constant_int = ValueBuilder; +using dealloc = OperationBuilder; +using dim = ValueBuilder; +using muli = ValueBuilder; +using mulf = ValueBuilder; +using memref_cast = ValueBuilder; +using ret = OperationBuilder; +using select = ValueBuilder; +using std_load = ValueBuilder; +using std_store = OperationBuilder; +using subi = ValueBuilder; +using tanh = ValueBuilder; +using view = ValueBuilder; + +/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`. +/// +/// Prerequisites: +/// All Handles have already captured previously constructed IR objects. +OperationHandle br(BlockHandle bh, ArrayRef operands); + +/// Creates a new mlir::Block* and branches to it from the current block. +/// Argument types are specified by `operands`. +/// Captures the new block in `bh` and the actual `operands` in `captures`. 
To +/// insert the new mlir::Block*, a local ScopedContext is constructed and +/// released to the current block. The branch operation is then added to the +/// new block. +/// +/// Prerequisites: +/// `b` has not yet captured an mlir::Block*. +/// No `captures` have captured any mlir::Value. +/// All `operands` have already captured an mlir::Value +/// captures.size() == operands.size() +/// captures and operands are pairwise of the same type. +OperationHandle br(BlockHandle *bh, ArrayRef captures, + ArrayRef operands); + +/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with +/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and +/// `falseOperands` if `cond` evaluates to `false`). +/// +/// Prerequisites: +/// All Handles have captured previously constructed IR objects. +OperationHandle cond_br(ValueHandle cond, BlockHandle trueBranch, + ArrayRef trueOperands, + BlockHandle falseBranch, + ArrayRef falseOperands); + +/// Eagerly creates new mlir::Block* with argument types specified by +/// `trueOperands`/`falseOperands`. +/// Captures the new blocks in `trueBranch`/`falseBranch` and the arguments in +/// `trueCaptures/falseCaptures`. +/// To insert the new mlir::Block*, a local ScopedContext is constructed and +/// released. The branch operation is then added in the original location and +/// targeting the eagerly constructed blocks. +/// +/// Prerequisites: +/// `trueBranch`/`falseBranch` has not yet captured an mlir::Block*. +/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value. +/// All `trueOperands`/`falseOperands` have already captured an mlir::Value +/// `trueCaptures`.size() == `trueOperands`.size() +/// `falseCaptures`.size() == `falseOperands`.size() +/// `trueCaptures` and `trueOperands` are pairwise of the same type +/// `falseCaptures` and `falseOperands` are pairwise of the same type. 
+OperationHandle cond_br(ValueHandle cond, BlockHandle *trueBranch, + ArrayRef trueCaptures, + ArrayRef trueOperands, + BlockHandle *falseBranch, + ArrayRef falseCaptures, + ArrayRef falseOperands); +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..4f218bd0d9b40963650939e68a2a13e7f6f04fe9 --- /dev/null +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -0,0 +1,126 @@ +//===- ExecutionEngine.h - MLIR Execution engine and utils -----*- C++ -*--===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides a JIT-backed execution engine for MLIR modules. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_ +#define MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ExecutionEngine/ObjectCache.h" +#include "llvm/ExecutionEngine/Orc/LLJIT.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/Error.h" + +#include +#include + +namespace llvm { +template class Expected; +class Module; +class ExecutionEngine; +class MemoryBuffer; +} // namespace llvm + +namespace mlir { + +class ModuleOp; + +/// A simple object cache following Lang's LLJITWithObjectCache example. +class SimpleObjectCache : public llvm::ObjectCache { +public: + void notifyObjectCompiled(const llvm::Module *M, + llvm::MemoryBufferRef ObjBuffer) override; + std::unique_ptr getObject(const llvm::Module *M) override; + + /// Dump cached object to output file `filename`. 
+ void dumpToObjectFile(StringRef filename); + +private: + llvm::StringMap> cachedObjects; +}; + +/// JIT-backed execution engine for MLIR modules. Assumes the module can be +/// converted to LLVM IR. For each function, creates a wrapper function with +/// the fixed interface +/// +/// void _mlir_funcName(void **) +/// +/// where the only argument is interpreted as a list of pointers to the actual +/// arguments of the function, followed by a pointer to the result. This allows +/// the engine to provide the caller with a generic function pointer that can +/// be used to invoke the JIT-compiled function. +class ExecutionEngine { +public: + ExecutionEngine(bool enableObjectCache); + + /// Creates an execution engine for the given module. If `transformer` is + /// provided, it will be called on the LLVM module during JIT-compilation and + /// can be used, e.g., for reporting or optimization. `jitCodeGenOptLevel`, + /// when provided, is used as the optimization level for target code + /// generation. If `sharedLibPaths` are provided, the underlying + /// JIT-compilation will open and link the shared libraries for symbol + /// resolution. If `objectCache` is provided, JIT compiler will use it to + /// store the object generated for the given module. + static llvm::Expected> create( + ModuleOp m, std::function transformer = {}, + Optional jitCodeGenOptLevel = llvm::None, + ArrayRef sharedLibPaths = {}, bool enableObjectCache = false); + + /// Looks up a packed-argument function with the given name and returns a + /// pointer to it. Propagates errors in case of failure. + llvm::Expected lookup(StringRef name) const; + + /// Invokes the function with the given name passing it the list of arguments. + /// The arguments are accepted by lvalue-reference since the packed function + /// interface expects a list of non-null pointers. + template + llvm::Error invoke(StringRef name, Args &... 
args); + + /// Invokes the function with the given name passing it the list of arguments + /// as a list of opaque pointers. This is the arity-agnostic equivalent of + /// the templated `invoke`. + llvm::Error invoke(StringRef name, MutableArrayRef args); + + /// Set the target triple on the module. This is implicitly done when creating + /// the engine. + static bool setupTargetTriple(llvm::Module *llvmModule); + + /// Dump object code to output file `filename`. + void dumpToObjectFile(StringRef filename); + +private: + // Ordering of llvmContext and jit is important for destruction purposes: the + // jit must be destroyed before the context. + llvm::LLVMContext llvmContext; + + // Underlying LLJIT. + std::unique_ptr jit; + + // Underlying cache. + std::unique_ptr cache; +}; + +template +llvm::Error ExecutionEngine::invoke(StringRef name, Args &... args) { + auto expectedFPtr = lookup(name); + if (!expectedFPtr) + return expectedFPtr.takeError(); + auto fptr = *expectedFPtr; + + SmallVector packedArgs{static_cast(&args)...}; + (*fptr)(packedArgs.data()); + + return llvm::Error::success(); +} + +} // end namespace mlir + +#endif // MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_ diff --git a/mlir/include/mlir/ExecutionEngine/OptUtils.h b/mlir/include/mlir/ExecutionEngine/OptUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7b7b2598db5bdada4a5ba3cffa2b4d6189d2e083 --- /dev/null +++ b/mlir/include/mlir/ExecutionEngine/OptUtils.h @@ -0,0 +1,57 @@ +//===- OptUtils.h - MLIR Execution Engine opt pass utilities ----*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the utility functions to trigger LLVM optimizations from +// MLIR Execution Engine. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_OPTUTILS_H_ +#define MLIR_EXECUTIONENGINE_OPTUTILS_H_ + +#include "llvm/Pass.h" + +#include +#include + +namespace llvm { +class Module; +class Error; +class TargetMachine; +} // namespace llvm + +namespace mlir { + +/// Initialize LLVM passes that can be used when running MLIR code using +/// ExecutionEngine. +void initializeLLVMPasses(); + +/// Create a module transformer function for MLIR ExecutionEngine that runs +/// LLVM IR passes corresponding to the given speed and size optimization +/// levels (e.g. -O2 or -Os). If not null, `targetMachine` is used to +/// initialize passes that provide target-specific information to the LLVM +/// optimizer. `targetMachine` must outlive the returned std::function. +std::function +makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel, + llvm::TargetMachine *targetMachine); + +/// Create a module transformer function for MLIR ExecutionEngine that runs +/// LLVM IR passes explicitly specified, plus an optional optimization level. +/// Any optimization passes, if present, will be inserted before the pass at +/// position optPassesInsertPos. If not null, `targetMachine` is used to +/// initialize passes that provide target-specific information to the LLVM +/// optimizer. `targetMachine` must outlive the returned std::function. 
+std::function +makeLLVMPassesTransformer(llvm::ArrayRef llvmPasses, + llvm::Optional mbOptLevel, + llvm::TargetMachine *targetMachine, + unsigned optPassesInsertPos = 0); + +} // end namespace mlir + +#endif // MLIR_EXECUTIONENGINE_OPTUTILS_H_ diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h new file mode 100644 index 0000000000000000000000000000000000000000..7059489ed4c9eda91ecd423f5a3652baf79be0e5 --- /dev/null +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -0,0 +1,321 @@ +//===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// An affine expression is an affine combination of dimension identifiers and +// symbols, including ceildiv/floordiv/mod by a constant integer. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_AFFINE_EXPR_H +#define MLIR_IR_AFFINE_EXPR_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/Support/Casting.h" +#include + +namespace mlir { + +class MLIRContext; +class AffineMap; +class IntegerSet; + +namespace detail { + +struct AffineExprStorage; +struct AffineBinaryOpExprStorage; +struct AffineDimExprStorage; +struct AffineSymbolExprStorage; +struct AffineConstantExprStorage; + +} // namespace detail + +enum class AffineExprKind { + Add, + /// RHS of mul is always a constant or a symbolic expression. + Mul, + /// RHS of mod is always a constant or a symbolic expression with a positive + /// value. + Mod, + /// RHS of floordiv is always a constant or a symbolic expression. + FloorDiv, + /// RHS of ceildiv is always a constant or a symbolic expression. 
+ CeilDiv, + + /// This is a marker for the last affine binary op. The range of binary + /// op's is expected to be this element and earlier. + LAST_AFFINE_BINARY_OP = CeilDiv, + + /// Constant integer. + Constant, + /// Dimensional identifier. + DimId, + /// Symbolic identifier. + SymbolId, +}; + +/// Base type for affine expression. +/// AffineExpr's are immutable value types with intuitive operators to +/// operate on chainable, lightweight compositions. +/// An AffineExpr is an interface to the underlying storage type pointer. +class AffineExpr { +public: + using ImplType = detail::AffineExprStorage; + + AffineExpr() : expr(nullptr) {} + /* implicit */ AffineExpr(const ImplType *expr) + : expr(const_cast(expr)) {} + + AffineExpr(const AffineExpr &other) : expr(other.expr) {} + AffineExpr &operator=(AffineExpr other) { + expr = other.expr; + return *this; + } + + bool operator==(AffineExpr other) const { return expr == other.expr; } + bool operator!=(AffineExpr other) const { return !(*this == other); } + bool operator==(int64_t v) const; + bool operator!=(int64_t v) const { return !(*this == v); } + explicit operator bool() const { return expr; } + + bool operator!() const { return expr == nullptr; } + + template bool isa() const; + template U dyn_cast() const; + template U cast() const; + + MLIRContext *getContext() const; + + /// Return the classification for this type. + AffineExprKind getKind() const; + + void print(raw_ostream &os) const; + void dump() const; + + /// Returns true if this expression is made out of only symbols and + /// constants, i.e., it does not involve dimensional identifiers. + bool isSymbolicOrConstant() const; + + /// Returns true if this is a pure affine expression, i.e., multiplication, + /// floordiv, ceildiv, and mod is only allowed w.r.t constants. + bool isPureAffine() const; + + /// Returns the greatest known integral divisor of this affine expression. The + /// result is always positive. 
+ int64_t getLargestKnownDivisor() const; + + /// Return true if the affine expression is a multiple of 'factor'. + bool isMultipleOf(int64_t factor) const; + + /// Return true if the affine expression involves AffineDimExpr `position`. + bool isFunctionOfDim(unsigned position) const; + + /// Walk all of the AffineExpr's in this expression in postorder. + void walk(std::function callback) const; + + /// This method substitutes any uses of dimensions and symbols (e.g. + /// dim#0 with dimReplacements[0]) and returns the modified expression tree. + AffineExpr replaceDimsAndSymbols(ArrayRef dimReplacements, + ArrayRef symReplacements) const; + + AffineExpr operator+(int64_t v) const; + AffineExpr operator+(AffineExpr other) const; + AffineExpr operator-() const; + AffineExpr operator-(int64_t v) const; + AffineExpr operator-(AffineExpr other) const; + AffineExpr operator*(int64_t v) const; + AffineExpr operator*(AffineExpr other) const; + AffineExpr floorDiv(uint64_t v) const; + AffineExpr floorDiv(AffineExpr other) const; + AffineExpr ceilDiv(uint64_t v) const; + AffineExpr ceilDiv(AffineExpr other) const; + AffineExpr operator%(uint64_t v) const; + AffineExpr operator%(AffineExpr other) const; + + /// Compose with an AffineMap. + /// Returns the composition of this AffineExpr with `map`. + /// + /// Prerequisites: + /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of + /// `this` is smaller than the number of results of `map`. If a result of a + /// map does not have a corresponding AffineDimExpr, that result simply does + /// not appear in the produced AffineExpr. + /// + /// Example: + /// expr: `d0 + d2` + /// map: `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)` + /// returned expr: `d0 * 2 + d1 + d2 + s1` + AffineExpr compose(AffineMap map) const; + + friend ::llvm::hash_code hash_value(AffineExpr arg); + +protected: + ImplType *expr; +}; + +/// Affine binary operation expression. 
/// An affine binary operation could be an
/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
/// represented through a multiply by -1 and add.) These expressions are always
/// constructed in a simplified form. For eg., the LHS and RHS operands can't
/// both be constants. There are additional canonicalizing rules depending on
/// the op type: see checks in the constructor.
class AffineBinaryOpExpr : public AffineExpr {
public:
  using ImplType = detail::AffineBinaryOpExprStorage;
  /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
  AffineExpr getLHS() const;
  AffineExpr getRHS() const;
};

/// A dimensional identifier appearing in an affine expression.
class AffineDimExpr : public AffineExpr {
public:
  using ImplType = detail::AffineDimExprStorage;
  /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
  unsigned getPosition() const;
};

/// A symbolic identifier appearing in an affine expression.
class AffineSymbolExpr : public AffineExpr {
public:
  // NOTE(review): this reuses AffineDimExprStorage even though an
  // AffineSymbolExprStorage is forward-declared at the top of this header --
  // confirm which storage the uniquer actually allocates for symbol exprs.
  using ImplType = detail::AffineDimExprStorage;
  /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
  unsigned getPosition() const;
};

/// An integer constant appearing in affine expression.
class AffineConstantExpr : public AffineExpr {
public:
  using ImplType = detail::AffineConstantExprStorage;
  /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr);
  int64_t getValue() const;
};

/// Make AffineExpr hashable.
inline ::llvm::hash_code hash_value(AffineExpr arg) {
  return ::llvm::hash_value(arg.expr);
}

// Mixed int64_t/AffineExpr arithmetic with the constant on the left; all
// forward to the member operators (subtraction via multiply-by-minus-one).
inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
inline AffineExpr operator-(int64_t val, AffineExpr expr) {
  return expr * (-1) + val;
}

/// These free functions allow clients of the API to not use classes in detail.
+AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context); +AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context); +AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context); +AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, + AffineExpr rhs); + +/// Constructs an affine expression from a flat ArrayRef. If there are local +/// identifiers (neither dimensional nor symbolic) that appear in the sum of +/// products expression, 'localExprs' is expected to have the AffineExpr +/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the +/// format [dims, symbols, locals, constant term]. +AffineExpr toAffineExpr(ArrayRef eq, unsigned numDims, + unsigned numSymbols, ArrayRef localExprs, + MLIRContext *context); + +raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr); + +template bool AffineExpr::isa() const { + if (std::is_same::value) { + return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP; + } + if (std::is_same::value) { + return getKind() == AffineExprKind::DimId; + } + if (std::is_same::value) { + return getKind() == AffineExprKind::SymbolId; + } + if (std::is_same::value) { + return getKind() == AffineExprKind::Constant; + } +} +template U AffineExpr::dyn_cast() const { + if (isa()) { + return U(expr); + } + return U(nullptr); +} +template U AffineExpr::cast() const { + assert(isa()); + return U(expr); +} + +/// Simplify an affine expression by flattening and some amount of +/// simple analysis. This has complexity linear in the number of nodes in +/// 'expr'. Returns the simplified expression, which is the same as the input +/// expression if it can't be simplified. +AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols); + +/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false +/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled). 
+/// See documentation for AffineExprFlattener on how mod's and div's are +/// flattened. +bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + SmallVectorImpl *flattenedExpr); + +/// Flattens the result expressions of the map to their corresponding flattened +/// forms and set in 'flattenedExprs'. Returns true on success or false +/// if any expression in the map could not be flattened (i.e., semi-affine is +/// not yet handled). For all affine expressions that share the same operands +/// (like those of an affine map), this method should be used instead of +/// repeatedly calling getFlattenedAffineExpr since local variables added to +/// deal with div's and mod's will be reused across expressions. +bool getFlattenedAffineExprs( + AffineMap map, std::vector> *flattenedExprs); +bool getFlattenedAffineExprs( + IntegerSet set, std::vector> *flattenedExprs); + +namespace detail { +template void bindDims(MLIRContext *ctx) {} + +template +void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &... exprs) { + e = getAffineDimExpr(N, ctx); + bindDims(ctx, exprs...); +} +} // namespace detail + +/// Bind a list of AffineExpr references to DimExpr at positions: +/// [0 .. sizeof...(exprs)] +template +void bindDims(MLIRContext *ctx, AffineExprTy &... 
exprs) { + detail::bindDims<0>(ctx, exprs...); +} + +} // namespace mlir + +namespace llvm { + +// AffineExpr hash just like pointers +template <> struct DenseMapInfo { + static mlir::AffineExpr getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::AffineExpr(static_cast(pointer)); + } + static mlir::AffineExpr getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::AffineExpr(static_cast(pointer)); + } + static unsigned getHashValue(mlir::AffineExpr val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) { + return LHS == RHS; + } +}; + +} // namespace llvm + +#endif // MLIR_IR_AFFINE_EXPR_H diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h new file mode 100644 index 0000000000000000000000000000000000000000..7866d6bb996a86ec380db07e360dd5369a7700b6 --- /dev/null +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -0,0 +1,325 @@ +//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the AffineExpr visitor class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H +#define MLIR_IR_AFFINE_EXPR_VISITOR_H + +#include "mlir/IR/AffineExpr.h" + +namespace mlir { + +/// Base class for AffineExpr visitors/walkers. +/// +/// AffineExpr visitors are used when you want to perform different actions +/// for different kinds of AffineExprs without having to use lots of casts +/// and a big switch instruction. 
///
/// To define your own visitor, inherit from this class, specifying your
/// new type for the 'SubClass' template parameter, and "override" visitXXX
/// functions in your class. This class is defined in terms of statically
/// resolved overloading, not virtual functions.
///
/// For example, here is a visitor that counts the number of AffineDimExprs
/// in an AffineExpr.
///
/// /// Declare the class. Note that we derive from AffineExprVisitor
/// /// instantiated with our new subclass's type.
///
/// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
///   unsigned numDimExprs;
///   DimExprCounter() : numDimExprs(0) {}
///   void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
/// };
///
/// And this class would be used like this:
///   DimExprCounter dec;
///   dec.visit(affineExpr);
///   numDimExprs = dec.numDimExprs;
///
/// AffineExprVisitor provides visit methods for the following binary affine
/// op expressions:
/// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
/// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
/// AffineBinaryCeilDivOpExpr. Note that default implementations of these
/// methods will call the general AffineBinaryOpExpr method.
///
/// In addition, visit methods are provided for the following affine
/// expressions: AffineConstantExpr, AffineDimExpr, and
/// AffineSymbolExpr.
///
/// Note that if you don't implement visitXXX for some affine expression type,
/// the visitXXX method for Instruction superclass will be invoked.
///
/// Note that this class is specifically designed as a template to avoid
/// virtual function call overhead. Defining and using a AffineExprVisitor is
/// just as efficient as having your own switch instruction over the
/// instruction opcode.
+ +template class AffineExprVisitor { + //===--------------------------------------------------------------------===// + // Interface code - This is the public interface of the AffineExprVisitor + // that you use to visit affine expressions... +public: + // Function to walk an AffineExpr (in post order). + RetTy walkPostOrder(AffineExpr expr) { + static_assert(std::is_base_of::value, + "Must instantiate with a derived type of AffineExprVisitor"); + switch (expr.getKind()) { + case AffineExprKind::Add: { + auto binOpExpr = expr.cast(); + walkOperandsPostOrder(binOpExpr); + return static_cast(this)->visitAddExpr(binOpExpr); + } + case AffineExprKind::Mul: { + auto binOpExpr = expr.cast(); + walkOperandsPostOrder(binOpExpr); + return static_cast(this)->visitMulExpr(binOpExpr); + } + case AffineExprKind::Mod: { + auto binOpExpr = expr.cast(); + walkOperandsPostOrder(binOpExpr); + return static_cast(this)->visitModExpr(binOpExpr); + } + case AffineExprKind::FloorDiv: { + auto binOpExpr = expr.cast(); + walkOperandsPostOrder(binOpExpr); + return static_cast(this)->visitFloorDivExpr(binOpExpr); + } + case AffineExprKind::CeilDiv: { + auto binOpExpr = expr.cast(); + walkOperandsPostOrder(binOpExpr); + return static_cast(this)->visitCeilDivExpr(binOpExpr); + } + case AffineExprKind::Constant: + return static_cast(this)->visitConstantExpr( + expr.cast()); + case AffineExprKind::DimId: + return static_cast(this)->visitDimExpr( + expr.cast()); + case AffineExprKind::SymbolId: + return static_cast(this)->visitSymbolExpr( + expr.cast()); + } + } + + // Function to visit an AffineExpr. 
+ RetTy visit(AffineExpr expr) { + static_assert(std::is_base_of::value, + "Must instantiate with a derived type of AffineExprVisitor"); + switch (expr.getKind()) { + case AffineExprKind::Add: { + auto binOpExpr = expr.cast(); + return static_cast(this)->visitAddExpr(binOpExpr); + } + case AffineExprKind::Mul: { + auto binOpExpr = expr.cast(); + return static_cast(this)->visitMulExpr(binOpExpr); + } + case AffineExprKind::Mod: { + auto binOpExpr = expr.cast(); + return static_cast(this)->visitModExpr(binOpExpr); + } + case AffineExprKind::FloorDiv: { + auto binOpExpr = expr.cast(); + return static_cast(this)->visitFloorDivExpr(binOpExpr); + } + case AffineExprKind::CeilDiv: { + auto binOpExpr = expr.cast(); + return static_cast(this)->visitCeilDivExpr(binOpExpr); + } + case AffineExprKind::Constant: + return static_cast(this)->visitConstantExpr( + expr.cast()); + case AffineExprKind::DimId: + return static_cast(this)->visitDimExpr( + expr.cast()); + case AffineExprKind::SymbolId: + return static_cast(this)->visitSymbolExpr( + expr.cast()); + } + llvm_unreachable("Unknown AffineExpr"); + } + + //===--------------------------------------------------------------------===// + // Visitation functions... these functions provide default fallbacks in case + // the user does not specify what to do for a particular instruction type. + // The default behavior is to generalize the instruction type to its subtype + // and try visiting the subtype. All of this should be inlined perfectly, + // because there are no virtual functions to get in the way. + // + + // Default visit methods. Note that the default op-specific binary op visit + // methods call the general visitAffineBinaryOpExpr visit method. 
+ void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {} + void visitAddExpr(AffineBinaryOpExpr expr) { + static_cast(this)->visitAffineBinaryOpExpr(expr); + } + void visitMulExpr(AffineBinaryOpExpr expr) { + static_cast(this)->visitAffineBinaryOpExpr(expr); + } + void visitModExpr(AffineBinaryOpExpr expr) { + static_cast(this)->visitAffineBinaryOpExpr(expr); + } + void visitFloorDivExpr(AffineBinaryOpExpr expr) { + static_cast(this)->visitAffineBinaryOpExpr(expr); + } + void visitCeilDivExpr(AffineBinaryOpExpr expr) { + static_cast(this)->visitAffineBinaryOpExpr(expr); + } + void visitConstantExpr(AffineConstantExpr expr) {} + void visitDimExpr(AffineDimExpr expr) {} + void visitSymbolExpr(AffineSymbolExpr expr) {} + +private: + // Walk the operands - each operand is itself walked in post order. + void walkOperandsPostOrder(AffineBinaryOpExpr expr) { + walkPostOrder(expr.getLHS()); + walkPostOrder(expr.getRHS()); + } +}; + +// This class is used to flatten a pure affine expression (AffineExpr, +// which is in a tree form) into a sum of products (w.r.t constants) when +// possible, and in that process simplifying the expression. For a modulo, +// floordiv, or a ceildiv expression, an additional identifier, called a local +// identifier, is introduced to rewrite the expression as a sum of product +// affine expression. Each local identifier is always and by construction a +// floordiv of a pure add/mul affine function of dimensional, symbolic, and +// other local identifiers, in a non-mutually recursive way. Hence, every local +// identifier can ultimately always be recovered as an affine function of +// dimensional and symbolic identifiers (involving floordiv's); note however +// that by AffineExpr construction, some floordiv combinations are converted to +// mod's. The result of the flattening is a flattened expression and a set of +// constraints involving just the local variables. 
//
// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
//
// The simplification performed includes the accumulation of contributions for
// each dimensional and symbolic identifier together, the simplification of
// floordiv/ceildiv/mod expressions and other simplifications that in turn
// happen as a result. A simplification that this flattening naturally performs
// is of simplifying the numerator and denominator of floordiv/ceildiv, and
// folding a modulo expression to a zero, if possible. Three examples are below:
//
// (((d0 + 3 * d1) + d0) - 2 * d1) - d0    simplified to  d0 + d1
// (d0 - d0 mod 4 + 4) mod 4               simplified to  0
// (3*d0 + 2*d1 + d0) floordiv 2 + d1      simplified to  2*d0 + 2*d1
//
// The way the flattening works for the second example is as follows: d0 % 4 is
// replaced by d0 - 4*q with q being introduced: the expression then simplifies
// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
// zero. Note that an affine expression may not always be expressible purely as
// a sum of products involving just the original dimensional and symbolic
// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
// may not be eliminated after simplification; in such cases, the final
// expression can be reconstructed by replacing the local identifiers with their
// corresponding explicit form stored in 'localExprs' (note that each of the
// explicit forms itself would have been simplified).
//
// The expression walk method here performs a linear time post order walk that
// performs the above simplifications through visit methods, with partial
// results being stored in 'operandExprStack'.
When a parent expr is visited, +// the flattened expressions corresponding to its two operands would already be +// on the stack - the parent expression looks at the two flattened expressions +// and combines the two. It pops off the operand expressions and pushes the +// combined result (although this is done in-place on its LHS operand expr). +// When the walk is completed, the flattened form of the top-level expression +// would be left on the stack. +// +// A flattener can be repeatedly used for multiple affine expressions that bind +// to the same operands, for example, for all result expressions of an +// AffineMap or AffineValueMap. In such cases, using it for multiple expressions +// is more efficient than creating a new flattener for each expression since +// common identical div and mod expressions appearing across different +// expressions are mapped to the same local identifier (same column position in +// 'localVarCst'). +class SimpleAffineExprFlattener + : public AffineExprVisitor { +public: + // Flattend expression layout: [dims, symbols, locals, constant] + // Stack that holds the LHS and RHS operands while visiting a binary op expr. + // In future, consider adding a prepass to determine how big the SmallVector's + // will be, and linearize this to std::vector to prevent + // SmallVector moves on re-allocation. + std::vector> operandExprStack; + + unsigned numDims; + unsigned numSymbols; + + // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's. + unsigned numLocals; + + // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for + // which new identifiers were introduced; if the latter do not get canceled + // out, these expressions can be readily used to reconstruct the AffineExpr + // (tree) form. Note that these expressions themselves would have been + // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 + // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. 
(d0 + d1) + // ceildiv 2 would be the local expression stored for q. + SmallVector localExprs; + + SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols); + + virtual ~SimpleAffineExprFlattener() = default; + + // Visitor method overrides. + void visitMulExpr(AffineBinaryOpExpr expr); + void visitAddExpr(AffineBinaryOpExpr expr); + void visitDimExpr(AffineDimExpr expr); + void visitSymbolExpr(AffineSymbolExpr expr); + void visitConstantExpr(AffineConstantExpr expr); + void visitCeilDivExpr(AffineBinaryOpExpr expr); + void visitFloorDivExpr(AffineBinaryOpExpr expr); + + // + // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 + // + // A mod expression "expr mod c" is thus flattened by introducing a new local + // variable q (= expr floordiv c), such that expr mod c is replaced with + // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. + void visitModExpr(AffineBinaryOpExpr expr); + +protected: + // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). + // The local identifier added is always a floordiv of a pure add/mul affine + // function of other identifiers, coefficients of which are specified in + // dividend and with respect to a positive constant divisor. localExpr is the + // simplified tree expression (AffineExpr) corresponding to the quantifier. + virtual void addLocalFloorDivId(ArrayRef dividend, int64_t divisor, + AffineExpr localExpr); + +private: + // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 + // A floordiv is thus flattened by introducing a new local variable q, and + // replacing that expression with 'q' while adding the constraints + // c * q <= expr <= c * q + c - 1 to localVarCst (done by + // FlatAffineConstraints::addLocalFloorDiv). 
+ // + // A ceildiv is similarly flattened: + // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c + void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil); + + int findLocalId(AffineExpr localExpr); + + inline unsigned getNumCols() const { + return numDims + numSymbols + numLocals + 1; + } + inline unsigned getConstantIndex() const { return getNumCols() - 1; } + inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } + inline unsigned getSymbolStartIndex() const { return numDims; } + inline unsigned getDimStartIndex() const { return 0; } +}; + +} // end namespace mlir + +#endif // MLIR_IR_AFFINE_EXPR_VISITOR_H diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h new file mode 100644 index 0000000000000000000000000000000000000000..3f9116cb1687c0663e9fb6dc30a1b7fa81449058 --- /dev/null +++ b/mlir/include/mlir/IR/AffineMap.h @@ -0,0 +1,251 @@ +//===- AffineMap.h - MLIR Affine Map Class ----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Affine maps are mathematical functions which map a list of dimension +// identifiers and symbols, to multidimensional affine expressions. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_AFFINE_MAP_H +#define MLIR_IR_AFFINE_MAP_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMapInfo.h" + +namespace mlir { + +namespace detail { +struct AffineMapStorage; +} // end namespace detail + +class AffineExpr; +class Attribute; +struct LogicalResult; +class MLIRContext; + +/// A multi-dimensional affine map +/// Affine map's are immutable like Type's, and they are uniqued. 
+/// Eg: (d0, d1) -> (d0/128, d0 mod 128, d1) +/// The names used (d0, d1) don't matter - it's the mathematical function that +/// is unique to this affine map. +class AffineMap { +public: + using ImplType = detail::AffineMapStorage; + + AffineMap() : map(nullptr) {} + explicit AffineMap(ImplType *map) : map(map) {} + AffineMap(const AffineMap &other) : map(other.map) {} + AffineMap &operator=(const AffineMap &other) = default; + + /// Returns a zero result affine map with no dimensions or symbols: () -> (). + static AffineMap get(MLIRContext *context); + + static AffineMap get(unsigned dimCount, unsigned symbolCount, + ArrayRef results); + + /// Returns a single constant result affine map. + static AffineMap getConstantMap(int64_t val, MLIRContext *context); + + /// Returns an AffineMap with 'numDims' identity result dim exprs. + static AffineMap getMultiDimIdentityMap(unsigned numDims, + MLIRContext *context); + + /// Returns an AffineMap representing a permutation. + /// The permutation is expressed as a non-empty vector of integers. + /// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with + /// `permutation = [1,2,0]`. All values in `permutation` must be + /// integers, in the range 0..`permutation.size()-1` without duplications + /// (i.e. `[1,1,2]` is an invalid permutation). + static AffineMap getPermutationMap(ArrayRef permutation, + MLIRContext *context); + + MLIRContext *getContext() const; + + explicit operator bool() { return map != nullptr; } + bool operator==(AffineMap other) const { return other.map == map; } + bool operator!=(AffineMap other) const { return !(other.map == map); } + + /// Returns true if this affine map is an identity affine map. + /// An identity affine map corresponds to an identity affine function on the + /// dimensional identifiers. + bool isIdentity() const; + + /// Returns true if this affine map is an empty map, i.e., () -> (). 
+ bool isEmpty() const; + + /// Returns true if this affine map is a single result constant function. + bool isSingleConstant() const; + + /// Returns the constant result of this map. This methods asserts that the map + /// has a single constant result. + int64_t getSingleConstantResult() const; + + // Prints affine map to 'os'. + void print(raw_ostream &os) const; + void dump() const; + + unsigned getNumDims() const; + unsigned getNumSymbols() const; + unsigned getNumResults() const; + unsigned getNumInputs() const; + + ArrayRef getResults() const; + AffineExpr getResult(unsigned idx) const; + + /// Walk all of the AffineExpr's in this mapping. Each node in an expression + /// tree is visited in postorder. + void walkExprs(std::function callback) const; + + /// This method substitutes any uses of dimensions and symbols (e.g. + /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified + /// expression mapping. Because this can be used to eliminate dims and + /// symbols, the client needs to specify the number of dims and symbols in + /// the result. The returned map always has the same number of results. + AffineMap replaceDimsAndSymbols(ArrayRef dimReplacements, + ArrayRef symReplacements, + unsigned numResultDims, + unsigned numResultSyms); + + /// Folds the results of the application of an affine map on the provided + /// operands to a constant if possible. + LogicalResult constantFold(ArrayRef operandConstants, + SmallVectorImpl &results) const; + + /// Returns the AffineMap resulting from composing `this` with `map`. + /// The resulting AffineMap has as many AffineDimExpr as `map` and as many + /// AffineSymbolExpr as the concatenation of `this` and `map` (in which case + /// the symbols of `this` map come first). + /// + /// Prerequisites: + /// The maps are composable, i.e. that the number of AffineDimExpr of `this` + /// matches the number of results of `map`. 
+ /// + /// Example: + /// map1: `(d0, d1)[s0, s1] -> (d0 + 1 + s1, d1 - 1 - s0)` + /// map2: `(d0)[s0] -> (d0 + s0, d0 - s0)` + /// map1.compose(map2): + /// `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)` + AffineMap compose(AffineMap map); + + /// Returns true if the AffineMap represents a subset (i.e. a projection) of a + /// symbol-less permutation map. + bool isProjectedPermutation(); + + /// Returns true if the AffineMap represents a symbol-less permutation map. + bool isPermutation(); + + /// Returns the map consisting of the `resultPos` subset. + AffineMap getSubMap(ArrayRef resultPos); + + friend ::llvm::hash_code hash_value(AffineMap arg); + +private: + ImplType *map; + + static AffineMap getImpl(unsigned dimCount, unsigned symbolCount, + ArrayRef results, MLIRContext *context); +}; + +// Make AffineExpr hashable. +inline ::llvm::hash_code hash_value(AffineMap arg) { + return ::llvm::hash_value(arg.map); +} + +/// Simplify an affine map by simplifying its underlying AffineExpr results. +AffineMap simplifyAffineMap(AffineMap map); + +/// Returns a map of codomain to domain dimensions such that the first codomain +/// dimension for a particular domain dimension is selected. +/// Returns an empty map if the input map is empty or if `map` is not invertible +/// (i.e. `map` does not contain a subset that is a permutation of full domain +/// rank). +/// +/// Prerequisites: +/// 1. `map` has no symbols. 
+/// +/// Example 1: +/// +/// ```mlir +/// (d0, d1, d2) -> (d1, d1, d0, d2, d1, d2, d1, d0) +/// 0 2 3 +/// ``` +/// +/// returns: +/// +/// ```mlir +/// (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3) +/// ``` +/// +/// Example 2: +/// +/// ```mlir +/// (d0, d1, d2) -> (d1, d0 + d1, d0, d2, d1, d2, d1, d0) +/// 0 2 3 +/// ``` +/// +/// returns: +/// +/// ```mlir +/// (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3) +/// ``` +AffineMap inversePermutation(AffineMap map); + +/// Concatenates a list of `maps` into a single AffineMap, stepping over +/// potentially empty maps. Assumes each of the underlying map has 0 symbols. +/// The resulting map has a number of dims equal to the max of `maps`' dims and +/// the concatenated results as its results. +/// Returns an empty map if all input `maps` are empty. +/// +/// Example: +/// When applied to the following list of 3 affine maps, +/// +/// ```mlir +/// { +/// (i, j, k) -> (i, k), +/// (i, j, k) -> (k, j), +/// (i, j, k) -> (i, j) +/// } +/// ``` +/// +/// Returns the map: +/// +/// ```mlir +/// (i, j, k) -> (i, k, k, j, i, j) +/// ``` +AffineMap concatAffineMaps(ArrayRef maps); + +inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) { + map.print(os); + return os; +} +} // end namespace mlir + +namespace llvm { + +// AffineExpr hash just like pointers +template <> struct DenseMapInfo { + static mlir::AffineMap getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::AffineMap(static_cast(pointer)); + } + static mlir::AffineMap getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::AffineMap(static_cast(pointer)); + } + static unsigned getHashValue(mlir::AffineMap val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::AffineMap LHS, mlir::AffineMap RHS) { + return LHS == RHS; + } +}; + +} // namespace llvm + +#endif // MLIR_IR_AFFINE_MAP_H diff --git a/mlir/include/mlir/IR/AttributeSupport.h 
b/mlir/include/mlir/IR/AttributeSupport.h new file mode 100644 index 0000000000000000000000000000000000000000..9804d6866f85f921d0eb14e3d6a2fb744b49b95b --- /dev/null +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -0,0 +1,107 @@ +//===- AttributeSupport.h ---------------------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines support types for registering dialect extended attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_ATTRIBUTESUPPORT_H +#define MLIR_IR_ATTRIBUTESUPPORT_H + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StorageUniquerSupport.h" +#include "llvm/ADT/PointerIntPair.h" + +namespace mlir { +class MLIRContext; +class Type; + +//===----------------------------------------------------------------------===// +// AttributeStorage +//===----------------------------------------------------------------------===// + +namespace detail { +class AttributeUniquer; +} // end namespace detail + +/// Base storage class appearing in an attribute. Derived storage classes should +/// only be constructed within the context of the AttributeUniquer. +class AttributeStorage : public StorageUniquer::BaseStorage { + friend detail::AttributeUniquer; + friend StorageUniquer; + +public: + /// Get the type of this attribute. + Type getType() const; + + /// Get the dialect of this attribute. + Dialect &getDialect() const { + assert(dialect && "Malformed attribute storage object."); + return const_cast(*dialect); + } + +protected: + /// Construct a new attribute storage instance with the given type. + /// Note: All attributes require a valid type. 
If no type is provided here, + /// the type of the attribute will automatically default to NoneType + /// upon initialization in the uniquer. + AttributeStorage(Type type); + AttributeStorage(); + + /// Set the type of this attribute. + void setType(Type type); + + // Set the dialect for this storage instance. This is used by the + // AttributeUniquer when initializing a newly constructed storage object. + void initializeDialect(Dialect &newDialect) { dialect = &newDialect; } + +private: + /// The dialect for this attribute. + Dialect *dialect; + + /// The opaque type of the attribute value. + const void *type; +}; + +/// Default storage type for attributes that require no additional +/// initialization or storage. +using DefaultAttributeStorage = AttributeStorage; + +//===----------------------------------------------------------------------===// +// AttributeStorageAllocator +//===----------------------------------------------------------------------===// + +// This is a utility allocator used to allocate memory for instances of derived +// Attributes. +using AttributeStorageAllocator = StorageUniquer::StorageAllocator; + +//===----------------------------------------------------------------------===// +// AttributeUniquer +//===----------------------------------------------------------------------===// +namespace detail { +// A utility class to get, or create, unique instances of attributes within an +// MLIRContext. This class manages all creation and uniquing of attributes. +class AttributeUniquer { +public: + /// Get an uniqued instance of attribute T. + template + static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { + return ctx->getAttributeUniquer().get( + getInitFn(ctx, T::getClassID()), kind, std::forward(args)...); + } + +private: + /// Returns a functor used to initialize new attribute storage instances. 
+ static std::function + getInitFn(MLIRContext *ctx, const ClassID *const attrID); +}; +} // namespace detail + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h new file mode 100644 index 0000000000000000000000000000000000000000..b8398580f61c90ea0da96d4c2670f361168b4419 --- /dev/null +++ b/mlir/include/mlir/IR/Attributes.h @@ -0,0 +1,1440 @@ +//===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_ATTRIBUTES_H +#define MLIR_IR_ATTRIBUTES_H + +#include "mlir/IR/AttributeSupport.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/Sequence.h" + +namespace mlir { +class AffineMap; +class Dialect; +class FunctionType; +class Identifier; +class IntegerSet; +class Location; +class MLIRContext; +class ShapedType; +class Type; + +namespace detail { + +struct AffineMapAttributeStorage; +struct ArrayAttributeStorage; +struct BoolAttributeStorage; +struct DictionaryAttributeStorage; +struct IntegerAttributeStorage; +struct IntegerSetAttributeStorage; +struct FloatAttributeStorage; +struct OpaqueAttributeStorage; +struct StringAttributeStorage; +struct SymbolRefAttributeStorage; +struct TypeAttributeStorage; + +/// Elements Attributes. +struct DenseElementsAttributeStorage; +struct OpaqueElementsAttributeStorage; +struct SparseElementsAttributeStorage; +} // namespace detail + +/// Attributes are known-constant values of operations and functions. +/// +/// Instances of the Attribute class are references to immutable, uniqued, +/// and immortal values owned by MLIRContext. As such, an Attribute is a thin +/// wrapper around an underlying storage pointer. 
Attributes are usually passed +/// by value. +class Attribute { +public: + /// Integer identifier for all the concrete attribute kinds. + enum Kind { + // Reserve attribute kinds for dialect specific extensions. +#define DEFINE_SYM_KIND_RANGE(Dialect) \ + FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff, +#include "DialectSymbolRegistry.def" + }; + + /// Utility class for implementing attributes. + template + using AttrBase = detail::StorageUserBase; + + using ImplType = AttributeStorage; + using ValueType = void; + + Attribute() : impl(nullptr) {} + /* implicit */ Attribute(const ImplType *impl) + : impl(const_cast(impl)) {} + + Attribute(const Attribute &other) = default; + Attribute &operator=(const Attribute &other) = default; + + bool operator==(Attribute other) const { return impl == other.impl; } + bool operator!=(Attribute other) const { return !(*this == other); } + explicit operator bool() const { return impl; } + + bool operator!() const { return impl == nullptr; } + + template bool isa() const; + template U dyn_cast() const; + template U dyn_cast_or_null() const; + template U cast() const; + + // Support dyn_cast'ing Attribute to itself. + static bool classof(Attribute) { return true; } + + /// Return the classification for this attribute. + unsigned getKind() const { return impl->getKind(); } + + /// Return the type of this attribute. + Type getType() const; + + /// Return the context this attribute belongs to. + MLIRContext *getContext() const; + + /// Get the dialect this attribute is registered to. + Dialect &getDialect() const; + + /// Print the attribute. + void print(raw_ostream &os) const; + void dump() const; + + /// Get an opaque pointer to the attribute. + const void *getAsOpaquePointer() const { return impl; } + /// Construct an attribute from the opaque pointer representation. 
+ static Attribute getFromOpaquePointer(const void *ptr) { + return Attribute(reinterpret_cast(ptr)); + } + + friend ::llvm::hash_code hash_value(Attribute arg); + +protected: + ImplType *impl; +}; + +inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) { + attr.print(os); + return os; +} + +namespace StandardAttributes { +enum Kind { + AffineMap = Attribute::FIRST_STANDARD_ATTR, + Array, + Bool, + Dictionary, + Float, + Integer, + IntegerSet, + Opaque, + String, + SymbolRef, + Type, + Unit, + + /// Elements Attributes. + DenseElements, + OpaqueElements, + SparseElements, + FIRST_ELEMENTS_ATTR = DenseElements, + LAST_ELEMENTS_ATTR = SparseElements, + + /// Locations. + CallSiteLocation, + FileLineColLocation, + FusedLocation, + NameLocation, + OpaqueLocation, + UnknownLocation, + + // Represents a location as a 'void*' pointer to a front-end's opaque + // location information, which must live longer than the MLIR objects that + // refer to it. OpaqueLocation's are never serialized. + // + // TODO: OpaqueLocation, + + // Represents a value inlined through a function call. + // TODO: InlinedLocation, + + FIRST_LOCATION_ATTR = CallSiteLocation, + LAST_LOCATION_ATTR = UnknownLocation, +}; +} // namespace StandardAttributes + +//===----------------------------------------------------------------------===// +// AffineMapAttr +//===----------------------------------------------------------------------===// + +class AffineMapAttr + : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = AffineMap; + + static AffineMapAttr get(AffineMap value); + + AffineMap getValue() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. 
+ static bool kindof(unsigned kind) { + return kind == StandardAttributes::AffineMap; + } +}; + +//===----------------------------------------------------------------------===// +// ArrayAttr +//===----------------------------------------------------------------------===// + +/// Array attributes are lists of other attributes. They are not necessarily +/// type homogenous given that attributes don't, in general, carry types. +class ArrayAttr : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = ArrayRef; + + static ArrayAttr get(ArrayRef value, MLIRContext *context); + + ArrayRef getValue() const; + + /// Support range iteration. + using iterator = llvm::ArrayRef::iterator; + iterator begin() const { return getValue().begin(); } + iterator end() const { return getValue().end(); } + size_t size() const { return getValue().size(); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::Array; + } +}; + +//===----------------------------------------------------------------------===// +// BoolAttr +//===----------------------------------------------------------------------===// + +class BoolAttr : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = bool; + + static BoolAttr get(bool value, MLIRContext *context); + + bool getValue() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; } +}; + +//===----------------------------------------------------------------------===// +// DictionaryAttr +//===----------------------------------------------------------------------===// + +/// NamedAttribute is used for dictionary attributes, it holds an identifier for +/// the name and a value for the attribute. The attribute pointer should always +/// be non-null. 
+using NamedAttribute = std::pair; + +/// Dictionary attribute is an attribute that represents a sorted collection of +/// named attribute values. The elements are sorted by name, and each name must +/// be unique within the collection. +class DictionaryAttr + : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = ArrayRef; + + static DictionaryAttr get(ArrayRef value, + MLIRContext *context); + + ArrayRef getValue() const; + + /// Return the specified attribute if present, null otherwise. + Attribute get(StringRef name) const; + Attribute get(Identifier name) const; + + /// Support range iteration. + using iterator = llvm::ArrayRef::iterator; + iterator begin() const; + iterator end() const; + bool empty() const { return size() == 0; } + size_t size() const; + + /// Methods for supporting type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::Dictionary; + } +}; + +//===----------------------------------------------------------------------===// +// FloatAttr +//===----------------------------------------------------------------------===// + +class FloatAttr : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = APFloat; + + /// Return a float attribute for the specified value in the specified type. + /// These methods should only be used for simple constant values, e.g 1.0/2.0, + /// that are known-valid both as host double and the 'type' format. + static FloatAttr get(Type type, double value); + static FloatAttr getChecked(Type type, double value, Location loc); + + /// Return a float attribute for the specified value in the specified type. + static FloatAttr get(Type type, const APFloat &value); + static FloatAttr getChecked(Type type, const APFloat &value, Location loc); + + APFloat getValue() const; + + /// This function is used to convert the value to a double, even if it loses + /// precision. 
+ double getValueAsDouble() const; + static double getValueAsDouble(APFloat val); + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::Float; + } + + /// Verify the construction invariants for a double value. + static LogicalResult verifyConstructionInvariants(Optional loc, + MLIRContext *ctx, Type type, + double value); + static LogicalResult verifyConstructionInvariants(Optional loc, + MLIRContext *ctx, Type type, + const APFloat &value); +}; + +//===----------------------------------------------------------------------===// +// IntegerAttr +//===----------------------------------------------------------------------===// + +class IntegerAttr + : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = APInt; + + static IntegerAttr get(Type type, int64_t value); + static IntegerAttr get(Type type, const APInt &value); + + APInt getValue() const; + // TODO(jpienaar): Change callers to use getValue instead. + int64_t getInt() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::Integer; + } +}; + +//===----------------------------------------------------------------------===// +// IntegerSetAttr +//===----------------------------------------------------------------------===// + +class IntegerSetAttr + : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = IntegerSet; + + static IntegerSetAttr get(IntegerSet value); + + IntegerSet getValue() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. 
+ static bool kindof(unsigned kind) { + return kind == StandardAttributes::IntegerSet; + } +}; + +//===----------------------------------------------------------------------===// +// OpaqueAttr +//===----------------------------------------------------------------------===// + +/// Opaque attributes represent attributes of non-registered dialects. These are +/// attribute represented in their raw string form, and can only usefully be +/// tested for attribute equality. +class OpaqueAttr : public Attribute::AttrBase { +public: + using Base::Base; + + /// Get or create a new OpaqueAttr with the provided dialect and string data. + static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type, + MLIRContext *context); + + /// Get or create a new OpaqueAttr with the provided dialect and string data. + /// If the given identifier is not a valid namespace for a dialect, then a + /// null attribute is returned. + static OpaqueAttr getChecked(Identifier dialect, StringRef attrData, + Type type, Location location); + + /// Returns the dialect namespace of the opaque attribute. + Identifier getDialectNamespace() const; + + /// Returns the raw attribute data of the opaque attribute. + StringRef getAttrData() const; + + /// Verify the construction of an opaque attribute. + static LogicalResult verifyConstructionInvariants(Optional loc, + MLIRContext *context, + Identifier dialect, + StringRef attrData, + Type type); + + static bool kindof(unsigned kind) { + return kind == StandardAttributes::Opaque; + } +}; + +//===----------------------------------------------------------------------===// +// StringAttr +//===----------------------------------------------------------------------===// + +class StringAttr : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = StringRef; + + /// Get an instance of a StringAttr with the given string. 
+ static StringAttr get(StringRef bytes, MLIRContext *context); + + /// Get an instance of a StringAttr with the given string and Type. + static StringAttr get(StringRef bytes, Type type); + + StringRef getValue() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::String; + } +}; + +//===----------------------------------------------------------------------===// +// SymbolRefAttr +//===----------------------------------------------------------------------===// + +class FlatSymbolRefAttr; + +/// A symbol reference attribute represents a symbolic reference to another +/// operation. +class SymbolRefAttr + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Construct a symbol reference for the given value name. + static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx); + + /// Construct a symbol reference for the given value name, and a set of nested + /// references that are further resolve to a nested symbol. + static SymbolRefAttr get(StringRef value, + ArrayRef references, + MLIRContext *ctx); + + /// Returns the name of the top level symbol reference, i.e. the root of the + /// reference path. + StringRef getRootReference() const; + + /// Returns the name of the fully resolved symbol, i.e. the leaf of the + /// reference path. + StringRef getLeafReference() const; + + /// Returns the set of nested references representing the path to the symbol + /// nested under the root reference. + ArrayRef getNestedReferences() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::SymbolRef; + } +}; + +/// A symbol reference with a reference path containing a single element. This +/// is used to refer to an operation within the current symbol table. 
+class FlatSymbolRefAttr : public SymbolRefAttr { +public: + using SymbolRefAttr::SymbolRefAttr; + using ValueType = StringRef; + + /// Construct a symbol reference for the given value name. + static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) { + return SymbolRefAttr::get(value, ctx); + } + + /// Returns the name of the held symbol reference. + StringRef getValue() const { return getRootReference(); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(Attribute attr) { + SymbolRefAttr refAttr = attr.dyn_cast(); + return refAttr && refAttr.getNestedReferences().empty(); + } + +private: + using SymbolRefAttr::get; + using SymbolRefAttr::getNestedReferences; +}; + +//===----------------------------------------------------------------------===// +// Type +//===----------------------------------------------------------------------===// + +class TypeAttr : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = Type; + + static TypeAttr get(Type value); + + Type getValue() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; } +}; + +//===----------------------------------------------------------------------===// +// UnitAttr +//===----------------------------------------------------------------------===// + +/// Unit attributes are attributes that hold no specific value and are given +/// meaning by their existence. 
+class UnitAttr : public Attribute::AttrBase { +public: + using Base::Base; + + static UnitAttr get(MLIRContext *context); + + static bool kindof(unsigned kind) { return kind == StandardAttributes::Unit; } +}; + +//===----------------------------------------------------------------------===// +// Elements Attributes +//===----------------------------------------------------------------------===// + +namespace detail { +template class ElementsAttrIterator; +template class ElementsAttrRange; +} // namespace detail + +/// A base attribute that represents a reference to a static shaped tensor or +/// vector constant. +class ElementsAttr : public Attribute { +public: + using Attribute::Attribute; + template using iterator = detail::ElementsAttrIterator; + template using iterator_range = detail::ElementsAttrRange; + + /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor + /// with static shape. + ShapedType getType() const; + + /// Return the value at the given index. The index is expected to refer to a + /// valid element. + Attribute getValue(ArrayRef index) const; + + /// Return the value of type 'T' at the given index, where 'T' corresponds to + /// an Attribute type. + template T getValue(ArrayRef index) const { + return getValue(index).template cast(); + } + + /// Return the elements of this attribute as a value of type 'T'. Note: + /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support + /// iteration. + template iterator_range getValues() const; + + /// Return if the given 'index' refers to a valid element in this attribute. + bool isValidIndex(ArrayRef index) const; + + /// Returns the number of elements held by this attribute. + int64_t getNumElements() const; + + /// Generates a new ElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain integers. 
+ ElementsAttr mapValues(Type newElementType, + function_ref mapping) const; + + /// Generates a new ElementsAttr by mapping each float value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain floats. + ElementsAttr mapValues(Type newElementType, + function_ref mapping) const; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(Attribute attr) { + return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR && + attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR; + } + +protected: + /// Returns the 1 dimensional flattened row-major index from the given + /// multi-dimensional index. + uint64_t getFlattenedIndex(ArrayRef index) const; +}; + +namespace detail { +/// DenseElementsAttr data is aligned to uint64_t, so this traits class is +/// necessary to interop with PointerIntPair. +class DenseElementDataPointerTypeTraits { +public: + static inline const void *getAsVoidPointer(const char *ptr) { return ptr; } + static inline const char *getFromVoidPointer(const void *ptr) { + return static_cast(ptr); + } + + // Note: We could steal more bits if the need arises. + enum { NumLowBitsAvailable = 1 }; +}; + +/// Pair of raw pointer and a boolean flag of whether the pointer holds a splat, +using DenseIterPtrAndSplat = + llvm::PointerIntPair; + +/// Impl iterator for indexed DenseElementAttr iterators that records a data +/// pointer and data index that is adjusted for the case of a splat attribute. +template +class DenseElementIndexedIteratorImpl + : public indexed_accessor_iterator { +protected: + DenseElementIndexedIteratorImpl(const char *data, bool isSplat, + size_t dataIndex) + : indexed_accessor_iterator({data, isSplat}, dataIndex) {} + + /// Return the current index for this iterator, adjusted for the case of a + /// splat. + ptrdiff_t getDataIndex() const { + bool isSplat = this->base.getInt(); + return isSplat ? 
0 : this->index; + } + + /// Return the data base pointer. + const char *getData() const { return this->base.getPointer(); } +}; +} // namespace detail + +/// An attribute that represents a reference to a dense vector or tensor object. +/// +class DenseElementsAttr + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(Attribute attr) { + return attr.getKind() == StandardAttributes::DenseElements; + } + + /// Constructs a dense elements attribute from an array of element values. + /// Each element attribute value is expected to be an element of 'type'. + /// 'type' must be a vector or tensor with static shape. + static DenseElementsAttr get(ShapedType type, ArrayRef values); + + /// Constructs a dense integer elements attribute from an array of integer + /// or floating-point values. Each value is expected to be the same bitwidth + /// of the element type of 'type'. 'type' must be a vector or tensor with + /// static shape. + template ::is_integer || + llvm::is_one_of::value>::type> + static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { + const char *data = reinterpret_cast(values.data()); + return getRawIntOrFloat( + type, ArrayRef(data, values.size() * sizeof(T)), sizeof(T), + /*isInt=*/std::numeric_limits::is_integer); + } + + /// Constructs a dense integer elements attribute from a single element. + template ::is_integer || + llvm::is_one_of::value>::type> + static DenseElementsAttr get(const ShapedType &type, T value) { + return get(type, llvm::makeArrayRef(value)); + } + + /// Overload of the above 'get' method that is specialized for boolean values. + static DenseElementsAttr get(ShapedType type, ArrayRef values); + + /// Constructs a dense integer elements attribute from an array of APInt + /// values. Each APInt value is expected to have the same bitwidth as the + /// element type of 'type'. 
'type' must be a vector or tensor with static + /// shape. + static DenseElementsAttr get(ShapedType type, ArrayRef values); + + /// Constructs a dense float elements attribute from an array of APFloat + /// values. Each APFloat value is expected to have the same bitwidth as the + /// element type of 'type'. 'type' must be a vector or tensor with static + /// shape. + static DenseElementsAttr get(ShapedType type, ArrayRef values); + + /// Construct a dense elements attribute for an initializer_list of values. + /// Each value is expected to be the same bitwidth of the element type of + /// 'type'. 'type' must be a vector or tensor with static shape. + template + static DenseElementsAttr get(const ShapedType &type, + const std::initializer_list &list) { + return get(type, ArrayRef(list)); + } + + //===--------------------------------------------------------------------===// + // Iterators + //===--------------------------------------------------------------------===// + + /// A utility iterator that allows walking over the internal Attribute values + /// of a DenseElementsAttr. + class AttributeElementIterator + : public indexed_accessor_iterator { + public: + /// Accesses the Attribute value at this iterator position. + Attribute operator*() const; + + private: + friend DenseElementsAttr; + + /// Constructs a new iterator. + AttributeElementIterator(DenseElementsAttr attr, size_t index); + }; + + /// Iterator for walking raw element values of the specified type 'T', which + /// may be any c++ data type matching the stored representation: int32_t, + /// float, etc. + template + class ElementIterator + : public detail::DenseElementIndexedIteratorImpl, + const T> { + public: + /// Accesses the raw value at this iterator position. + const T &operator*() const { + return reinterpret_cast(this->getData())[this->getDataIndex()]; + } + + private: + friend DenseElementsAttr; + + /// Constructs a new iterator. 
+ ElementIterator(const char *data, bool isSplat, size_t dataIndex) + : detail::DenseElementIndexedIteratorImpl, const T>( + data, isSplat, dataIndex) {} + }; + + /// A utility iterator that allows walking over the internal bool values. + class BoolElementIterator + : public detail::DenseElementIndexedIteratorImpl { + public: + /// Accesses the bool value at this iterator position. + bool operator*() const; + + private: + friend DenseElementsAttr; + + /// Constructs a new iterator. + BoolElementIterator(DenseElementsAttr attr, size_t dataIndex); + }; + + /// A utility iterator that allows walking over the internal raw APInt values. + class IntElementIterator + : public detail::DenseElementIndexedIteratorImpl { + public: + /// Accesses the raw APInt value at this iterator position. + APInt operator*() const; + + private: + friend DenseElementsAttr; + + /// Constructs a new iterator. + IntElementIterator(DenseElementsAttr attr, size_t dataIndex); + + /// The bitwidth of the element type. + size_t bitWidth; + }; + + /// Iterator for walking over APFloat values. + class FloatElementIterator final + : public llvm::mapped_iterator> { + friend DenseElementsAttr; + + /// Initializes the float element iterator to the specified iterator. + FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it); + + public: + using reference = APFloat; + }; + + //===--------------------------------------------------------------------===// + // Value Querying + //===--------------------------------------------------------------------===// + + /// Returns if this attribute corresponds to a splat, i.e. if all element + /// values are the same. + bool isSplat() const; + + /// Return the splat value for this attribute. This asserts that the attribute + /// corresponds to a splat. 
+ Attribute getSplatValue() const { return getSplatValue(); } + template + typename std::enable_if::value || + std::is_same::value, + T>::type + getSplatValue() const { + assert(isSplat() && "expected the attribute to be a splat"); + return *getValues().begin(); + } + /// Return the splat value for derived attribute element types. + template + typename std::enable_if::value && + !std::is_same::value, + T>::type + getSplatValue() const { + return getSplatValue().template cast(); + } + + /// Return the value at the given index. The 'index' is expected to refer to a + /// valid element. + Attribute getValue(ArrayRef index) const { + return getValue(index); + } + template T getValue(ArrayRef index) const { + // Skip to the element corresponding to the flattened index. + return *std::next(getValues().begin(), getFlattenedIndex(index)); + } + + /// Return the held element values as a range of integer or floating-point + /// values. + template ::value && + std::numeric_limits::is_integer) || + llvm::is_one_of::value>::type> + llvm::iterator_range> getValues() const { + assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer)); + auto rawData = getRawData().data(); + bool splat = isSplat(); + return {ElementIterator(rawData, splat, 0), + ElementIterator(rawData, splat, getNumElements())}; + } + + /// Return the held element values as a range of Attributes. + llvm::iterator_range getAttributeValues() const; + template ::value>::type> + llvm::iterator_range getValues() const { + return getAttributeValues(); + } + AttributeElementIterator attr_value_begin() const; + AttributeElementIterator attr_value_end() const; + + /// Return the held element values a range of T, where T is a derived + /// attribute type. 
+ template + using DerivedAttributeElementIterator = + llvm::mapped_iterator; + template ::value && + !std::is_same::value>::type> + llvm::iterator_range> getValues() const { + auto castFn = [](Attribute attr) { return attr.template cast(); }; + return llvm::map_range(getAttributeValues(), + static_cast(castFn)); + } + + /// Return the held element values as a range of bool. The element type of + /// this attribute must be of integer type of bitwidth 1. + llvm::iterator_range getBoolValues() const; + template ::value>::type> + llvm::iterator_range getValues() const { + return getBoolValues(); + } + + /// Return the held element values as a range of APInts. The element type of + /// this attribute must be of integer type. + llvm::iterator_range getIntValues() const; + template ::value>::type> + llvm::iterator_range getValues() const { + return getIntValues(); + } + IntElementIterator int_value_begin() const; + IntElementIterator int_value_end() const; + + /// Return the held element values as a range of APFloat. The element type of + /// this attribute must be of float type. + llvm::iterator_range getFloatValues() const; + template ::value>::type> + llvm::iterator_range getValues() const { + return getFloatValues(); + } + FloatElementIterator float_value_begin() const; + FloatElementIterator float_value_end() const; + + //===--------------------------------------------------------------------===// + // Mutation Utilities + //===--------------------------------------------------------------------===// + + /// Return a new DenseElementsAttr that has the same data as the current + /// attribute, but has been reshaped to 'newType'. The new type must have the + /// same total number of elements as well as element type. + DenseElementsAttr reshape(ShapedType newType); + + /// Generates a new DenseElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. 
+ /// This underlying type must be an DenseIntElementsAttr. + DenseElementsAttr mapValues(Type newElementType, + function_ref mapping) const; + + /// Generates a new DenseElementsAttr by mapping each float value to a new + /// underlying APInt. the new values can represent either a integer or float. + /// This underlying type must be an DenseFPElementsAttr. + DenseElementsAttr + mapValues(Type newElementType, + function_ref mapping) const; + +protected: + /// Return the raw storage data held by this attribute. + ArrayRef getRawData() const; + + /// Get iterators to the raw APInt values for each element in this attribute. + IntElementIterator raw_int_begin() const { + return IntElementIterator(*this, 0); + } + IntElementIterator raw_int_end() const { + return IntElementIterator(*this, getNumElements()); + } + + /// Constructs a dense elements attribute from an array of raw APInt values. + /// Each APInt value is expected to have the same bitwidth as the element type + /// of 'type'. 'type' must be a vector or tensor with static shape. + static DenseElementsAttr getRaw(ShapedType type, ArrayRef values); + + /// Get or create a new dense elements attribute instance with the given raw + /// data buffer. 'type' must be a vector or tensor with static shape. + static DenseElementsAttr getRaw(ShapedType type, ArrayRef data, + bool isSplat); + + /// Overload of the raw 'get' method that asserts that the given type is of + /// integer or floating-point type. This method is used to verify type + /// invariants that the templatized 'get' method cannot. + static DenseElementsAttr getRawIntOrFloat(ShapedType type, + ArrayRef data, + int64_t dataEltSize, bool isInt); + + /// Check the information for a c++ data type, check if this type is valid for + /// the current attribute. This method is used to verify specific type + /// invariants that the templatized 'getValues' method cannot. 
+ bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const; +}; + +/// An attribute that represents a reference to a dense float vector or tensor +/// object. Each element is stored as a double. +class DenseFPElementsAttr : public DenseElementsAttr { +public: + using iterator = DenseElementsAttr::FloatElementIterator; + + using DenseElementsAttr::DenseElementsAttr; + + /// Get an instance of a DenseFPElementsAttr with the given arguments. This + /// simply wraps the DenseElementsAttr::get calls. + template + static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) { + return DenseElementsAttr::get(type, llvm::makeArrayRef(arg)) + .template cast(); + } + template + static DenseFPElementsAttr get(const ShapedType &type, + const std::initializer_list &list) { + return DenseElementsAttr::get(type, list) + .template cast(); + } + + /// Generates a new DenseElementsAttr by mapping each value attribute, and + /// constructing the DenseElementsAttr given the new element type. + DenseElementsAttr + mapValues(Type newElementType, + function_ref mapping) const; + + /// Iterator access to the float element values. + iterator begin() const { return float_value_begin(); } + iterator end() const { return float_value_end(); } + + /// Method for supporting type inquiry through isa, cast and dyn_cast. + static bool classof(Attribute attr); +}; + +/// An attribute that represents a reference to a dense integer vector or tensor +/// object. +class DenseIntElementsAttr : public DenseElementsAttr { +public: + /// DenseIntElementsAttr iterates on APInt, so we can use the raw element + /// iterator directly. + using iterator = DenseElementsAttr::IntElementIterator; + + using DenseElementsAttr::DenseElementsAttr; + + /// Get an instance of a DenseIntElementsAttr with the given arguments. This + /// simply wraps the DenseElementsAttr::get calls. 
+ template + static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) { + return DenseElementsAttr::get(type, llvm::makeArrayRef(arg)) + .template cast(); + } + template + static DenseIntElementsAttr get(const ShapedType &type, + const std::initializer_list &list) { + return DenseElementsAttr::get(type, list) + .template cast(); + } + + /// Generates a new DenseElementsAttr by mapping each value attribute, and + /// constructing the DenseElementsAttr given the new element type. + DenseElementsAttr mapValues(Type newElementType, + function_ref mapping) const; + + /// Iterator access to the integer element values. + iterator begin() const { return raw_int_begin(); } + iterator end() const { return raw_int_end(); } + + /// Method for supporting type inquiry through isa, cast and dyn_cast. + static bool classof(Attribute attr); +}; + +/// An opaque attribute that represents a reference to a vector or tensor +/// constant with opaque content. This representation is for tensor constants +/// which the compiler may not need to interpret. This attribute is always +/// associated with a particular dialect, which provides a method to convert +/// tensor representation to a non-opaque format. +class OpaqueElementsAttr + : public Attribute::AttrBase { +public: + using Base::Base; + using ValueType = StringRef; + + static OpaqueElementsAttr get(Dialect *dialect, ShapedType type, + StringRef bytes); + + StringRef getValue() const; + + /// Return the value at the given index. The 'index' is expected to refer to a + /// valid element. + Attribute getValue(ArrayRef index) const; + + /// Decodes the attribute value using dialect-specific decoding hook. + /// Returns false if decoding is successful. If not, returns true and leaves + /// 'result' argument unspecified. + bool decode(ElementsAttr &result); + + /// Returns dialect associated with this opaque constant. + Dialect *getDialect() const; + + /// Method for support type inquiry through isa, cast and dyn_cast. 
+ static bool kindof(unsigned kind) { + return kind == StandardAttributes::OpaqueElements; + } +}; + +/// An attribute that represents a reference to a sparse vector or tensor +/// object. +/// +/// This class uses COO (coordinate list) encoding to represent the sparse +/// elements in an element attribute. Specifically, the sparse vector/tensor +/// stores the indices and values as two separate dense elements attributes of +/// tensor type (even if the sparse attribute is of vector type, in order to +/// support empty lists). The dense elements attribute indices is a 2-D tensor +/// of 64-bit integer elements with shape [N, ndims], which specifies the +/// indices of the elements in the sparse tensor that contains nonzero values. +/// The dense elements attribute values is a 1-D tensor with shape [N], and it +/// supplies the corresponding values for the indices. +/// +/// For example, +/// `sparse, [[0, 0], [1, 2]], [1, 5]>` represents tensor +/// [[1, 0, 0, 0], +/// [0, 0, 5, 0], +/// [0, 0, 0, 0]]. +class SparseElementsAttr + : public Attribute::AttrBase { +public: + using Base::Base; + + template + using iterator = + llvm::mapped_iterator, + std::function>; + + /// 'type' must be a vector or tensor with static shape. + static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices, + DenseElementsAttr values); + + DenseIntElementsAttr getIndices() const; + + DenseElementsAttr getValues() const; + + /// Return the values of this attribute in the form of the given type 'T'. 'T' + /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc. + template llvm::iterator_range> getValues() const { + auto zeroValue = getZeroValue(); + auto valueIt = getValues().getValues().begin(); + const std::vector flatSparseIndices(getFlattenedSparseIndices()); + // TODO(riverriddle): Move-capture flatSparseIndices when c++14 is + // available. + std::function mapFn = [=](ptrdiff_t index) { + // Try to map the current index to one of the sparse indices. 
+ for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i) + if (flatSparseIndices[i] == index) + return *std::next(valueIt, i); + // Otherwise, return the zero value. + return zeroValue; + }; + return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); + } + + /// Return the value of the element at the given index. The 'index' is + /// expected to refer to a valid element. + Attribute getValue(ArrayRef index) const; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::SparseElements; + } + +private: + /// Get a zero APFloat for the given sparse attribute. + APFloat getZeroAPFloat() const; + + /// Get a zero APInt for the given sparse attribute. + APInt getZeroAPInt() const; + + /// Get a zero attribute for the given sparse attribute. + Attribute getZeroAttr() const; + + /// Utility methods to generate a zero value of some type 'T'. This is used by + /// the 'iterator' class. + /// Get a zero for a given attribute type. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAttr().template cast(); + } + /// Get a zero for an APInt. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAPInt(); + } + /// Get a zero for an APFloat. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAPFloat(); + } + /// Get a zero for an C++ integer or float type. + template + typename std::enable_if::is_integer || + llvm::is_one_of::value, + T>::type + getZeroValue() const { + return T(0); + } + + /// Flatten, and return, all of the sparse indices in this attribute in + /// row-major order. + std::vector getFlattenedSparseIndices() const; +}; + +/// An attribute that represents a reference to a splat vector or tensor +/// constant, meaning all of the elements have the same value. 
+class SplatElementsAttr : public DenseElementsAttr { +public: + using DenseElementsAttr::DenseElementsAttr; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(Attribute attr) { + auto denseAttr = attr.dyn_cast(); + return denseAttr && denseAttr.isSplat(); + } +}; + +namespace detail { +/// This class represents a general iterator over the values of an ElementsAttr. +/// It supports all subclasses aside from OpaqueElementsAttr. +template +class ElementsAttrIterator + : public llvm::iterator_facade_base, + std::random_access_iterator_tag, T, + std::ptrdiff_t, T, T> { + // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype' + // inside of a conversion operator. + using DenseIteratorT = typename std::enable_if< + true, + decltype(std::declval().getValues().begin())>::type; + using SparseIteratorT = SparseElementsAttr::iterator; + + /// A union containing the specific iterators for each derived attribute kind. + union Iterator { + Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {} + Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {} + Iterator() {} + ~Iterator() {} + + operator const DenseIteratorT &() const { return denseIt; } + operator const SparseIteratorT &() const { return sparseIt; } + operator DenseIteratorT &() { return denseIt; } + operator SparseIteratorT &() { return sparseIt; } + + /// An instance of a dense elements iterator. + DenseIteratorT denseIt; + /// An instance of a sparse elements iterator. + SparseIteratorT sparseIt; + }; + + /// Utility method to process a functor on each of the internal iterator + /// types. + template class ProcessFn, + typename... Args> + RetT process(Args &... 
args) const { + switch (attrKind) { + case StandardAttributes::DenseElements: + return ProcessFn()(args...); + case StandardAttributes::SparseElements: + return ProcessFn()(args...); + } + llvm_unreachable("unexpected attribute kind"); + } + + /// Utility functors used to generically implement the iterators methods. + template struct PlusAssign { + void operator()(ItT &it, ptrdiff_t offset) { it += offset; } + }; + template struct Minus { + ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; } + }; + template struct MinusAssign { + void operator()(ItT &it, ptrdiff_t offset) { it -= offset; } + }; + template struct Dereference { + T operator()(ItT &it) { return *it; } + }; + template struct ConstructIter { + void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); } + }; + template struct DestructIter { + void operator()(ItT &it) { it.~ItT(); } + }; + +public: + ElementsAttrIterator(const ElementsAttrIterator &rhs) + : attrKind(rhs.attrKind) { + process(it, rhs.it); + } + ~ElementsAttrIterator() { process(it); } + + /// Methods necessary to support random access iteration. + ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { + assert(attrKind == rhs.attrKind && "incompatible iterators"); + return process(it, rhs.it); + } + bool operator==(const ElementsAttrIterator &rhs) const { + return rhs.attrKind == attrKind && process(it, rhs.it); + } + bool operator<(const ElementsAttrIterator &rhs) const { + assert(attrKind == rhs.attrKind && "incompatible iterators"); + return process(it, rhs.it); + } + ElementsAttrIterator &operator+=(ptrdiff_t offset) { + process(it, offset); + return *this; + } + ElementsAttrIterator &operator-=(ptrdiff_t offset) { + process(it, offset); + return *this; + } + + /// Dereference the iterator at the current index. 
+ T operator*() { return process(it); } + +private: + template + ElementsAttrIterator(unsigned attrKind, IteratorT &&it) + : attrKind(attrKind), it(std::forward(it)) {} + + /// Allow accessing the constructor. + friend ElementsAttr; + + /// The kind of derived elements attribute. + unsigned attrKind; + + /// A union containing the specific iterators for each derived kind. + Iterator it; +}; + +template +class ElementsAttrRange : public llvm::iterator_range> { + using llvm::iterator_range>::iterator_range; +}; +} // namespace detail + +/// Return the elements of this attribute as a value of type 'T'. +template +auto ElementsAttr::getValues() const -> iterator_range { + if (DenseElementsAttr denseAttr = dyn_cast()) { + auto values = denseAttr.getValues(); + return {iterator(getKind(), values.begin()), + iterator(getKind(), values.end())}; + } + if (SparseElementsAttr sparseAttr = dyn_cast()) { + auto values = sparseAttr.getValues(); + return {iterator(getKind(), values.begin()), + iterator(getKind(), values.end())}; + } + llvm_unreachable("unexpected attribute kind"); +} + +//===----------------------------------------------------------------------===// +// Attributes Utils +//===----------------------------------------------------------------------===// + +template bool Attribute::isa() const { + assert(impl && "isa<> used on a null attribute."); + return U::classof(*this); +} +template U Attribute::dyn_cast() const { + return isa() ? U(impl) : U(nullptr); +} +template U Attribute::dyn_cast_or_null() const { + return (impl && isa()) ? U(impl) : U(nullptr); +} +template U Attribute::cast() const { + assert(isa()); + return U(impl); +} + +// Make Attribute hashable. 
+inline ::llvm::hash_code hash_value(Attribute arg) { + return ::llvm::hash_value(arg.impl); +} + +//===----------------------------------------------------------------------===// +// NamedAttributeList +//===----------------------------------------------------------------------===// + +/// A NamedAttributeList is used to manage a list of named attributes. This +/// provides simple interfaces for adding/removing/finding attributes from +/// within a DictionaryAttr. +/// +/// We assume there will be relatively few attributes on a given operation +/// (maybe a dozen or so, but not hundreds or thousands) so we use linear +/// searches for everything. +class NamedAttributeList { +public: + NamedAttributeList(DictionaryAttr attrs = nullptr) + : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {} + NamedAttributeList(ArrayRef attributes); + + bool operator!=(const NamedAttributeList &other) const { + return !(*this == other); + } + bool operator==(const NamedAttributeList &other) const { + return attrs == other.attrs; + } + + /// Return the underlying dictionary attribute. This may be null, if this list + /// has no attributes. + DictionaryAttr getDictionary() const { return attrs; } + + /// Return all of the attributes on this operation. + ArrayRef getAttrs() const; + + /// Replace the held attributes with ones provided in 'newAttrs'. + void setAttrs(ArrayRef attributes); + + /// Return the specified attribute if present, null otherwise. + Attribute get(StringRef name) const; + Attribute get(Identifier name) const; + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void set(Identifier name, Attribute value); + + enum class RemoveResult { Removed, NotFound }; + + /// Remove the attribute with the specified name if it exists. The return + /// value indicates whether the attribute was present or not. 
+ RemoveResult remove(Identifier name); + +private: + DictionaryAttr attrs; +}; + +} // end namespace mlir. + +namespace llvm { + +// Attribute hash just like pointers. +template <> struct DenseMapInfo { + static mlir::Attribute getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::Attribute(static_cast(pointer)); + } + static mlir::Attribute getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::Attribute(static_cast(pointer)); + } + static unsigned getHashValue(mlir::Attribute val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) { + return LHS == RHS; + } +}; + +/// Allow LLVM to steal the low bits of Attributes. +template <> struct PointerLikeTypeTraits { + static inline void *getAsVoidPointer(mlir::Attribute attr) { + return const_cast(attr.getAsOpaquePointer()); + } + static inline mlir::Attribute getFromVoidPointer(void *ptr) { + return mlir::Attribute::getFromOpaquePointer(ptr); + } + enum { NumLowBitsAvailable = 3 }; +}; + +template <> +struct PointerLikeTypeTraits + : public PointerLikeTypeTraits { + static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) { + return PointerLikeTypeTraits::getFromVoidPointer(ptr) + .cast(); + } +}; + +} // namespace llvm + +#endif diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h new file mode 100644 index 0000000000000000000000000000000000000000..934eed93c3b3327bfebf015a6975480aab67fb10 --- /dev/null +++ b/mlir/include/mlir/IR/Block.h @@ -0,0 +1,335 @@ +//===- Block.h - MLIR Block Class -------------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Block class. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BLOCK_H +#define MLIR_IR_BLOCK_H + +#include "mlir/IR/BlockSupport.h" +#include "mlir/IR/Visitors.h" + +namespace mlir { +/// `Block` represents an ordered list of `Operation`s. +class Block : public IRObjectWithUseList, + public llvm::ilist_node_with_parent { +public: + explicit Block() {} + ~Block(); + + void clear() { + // Drop all references from within this block. + dropAllReferences(); + + // Clear operations in the reverse order so that uses are destroyed + // before their defs. + while (!empty()) + operations.pop_back(); + } + + /// Provide a 'getParent' method for ilist_node_with_parent methods. + /// We mark it as a const function because ilist_node_with_parent specifically + /// requires a 'getParent() const' method. Once ilist_node removes this + /// constraint, we should drop the const to fit the rest of the MLIR const + /// model. + Region *getParent() const; + + /// Returns the closest surrounding operation that contains this block. + Operation *getParentOp(); + + /// Return if this block is the entry block in the parent region. + bool isEntryBlock(); + + /// Insert this block (which must not already be in a region) right before + /// the specified block. + void insertBefore(Block *block); + + /// Unlink this block from its current region and insert it right before the + /// specific block. + void moveBefore(Block *block); + + /// Unlink this Block from its parent region and delete it. + void erase(); + + //===--------------------------------------------------------------------===// + // Block argument management + //===--------------------------------------------------------------------===// + + // This is the list of arguments to the block. 
+ using BlockArgListType = MutableArrayRef; + + BlockArgListType getArguments() { return arguments; } + + using args_iterator = BlockArgListType::iterator; + using reverse_args_iterator = BlockArgListType::reverse_iterator; + args_iterator args_begin() { return getArguments().begin(); } + args_iterator args_end() { return getArguments().end(); } + reverse_args_iterator args_rbegin() { return getArguments().rbegin(); } + reverse_args_iterator args_rend() { return getArguments().rend(); } + + bool args_empty() { return arguments.empty(); } + + /// Add one value to the argument list. + BlockArgument addArgument(Type type); + + /// Add one argument to the argument list for each type specified in the list. + iterator_range addArguments(ArrayRef types); + + /// Erase the argument at 'index' and remove it from the argument list. If + /// 'updatePredTerms' is set to true, this argument is also removed from the + /// terminators of each predecessor to this block. + void eraseArgument(unsigned index, bool updatePredTerms = true); + + unsigned getNumArguments() { return arguments.size(); } + BlockArgument getArgument(unsigned i) { return arguments[i]; } + + //===--------------------------------------------------------------------===// + // Operation list management + //===--------------------------------------------------------------------===// + + /// This is the list of operations in the block. + using OpListType = llvm::iplist; + OpListType &getOperations() { return operations; } + + // Iteration over the operations in the block. 
+ using iterator = OpListType::iterator; + using reverse_iterator = OpListType::reverse_iterator; + + iterator begin() { return operations.begin(); } + iterator end() { return operations.end(); } + reverse_iterator rbegin() { return operations.rbegin(); } + reverse_iterator rend() { return operations.rend(); } + + bool empty() { return operations.empty(); } + void push_back(Operation *op) { operations.push_back(op); } + void push_front(Operation *op) { operations.push_front(op); } + + Operation &back() { return operations.back(); } + Operation &front() { return operations.front(); } + + /// Returns 'op' if 'op' lies in this block, or otherwise finds the + /// ancestor operation of 'op' that lies in this block. Returns nullptr if + /// the latter fails. + /// TODO: This is very specific functionality that should live somewhere else, + /// probably in Dominance.cpp. + Operation *findAncestorOpInBlock(Operation &op); + + /// This drops all operand uses from operations within this block, which is + /// an essential step in breaking cyclic dependences between references when + /// they are to be deleted. + void dropAllReferences(); + + /// This drops all uses of values defined in this block or in the blocks of + /// nested regions wherever the uses are located. + void dropAllDefinedValueUses(); + + /// Returns true if the ordering of the child operations is valid, false + /// otherwise. + bool isOpOrderValid(); + + /// Invalidates the current ordering of operations. + void invalidateOpOrder(); + + /// Verifies the current ordering of child operations matches the + /// validOpOrder flag. Returns false if the order is valid, true otherwise. + bool verifyOpOrder(); + + /// Recomputes the ordering of child operations within the block. + void recomputeOpOrder(); + +private: + /// A utility iterator that filters out operations that are not 'OpT'. 
+ template + class op_filter_iterator + : public llvm::filter_iterator { + static bool filter(Operation &op) { return llvm::isa(op); } + + public: + op_filter_iterator(Block::iterator it, Block::iterator end) + : llvm::filter_iterator( + it, end, &filter) {} + + /// Allow implicit conversion to the underlying block iterator. + operator Block::iterator() const { return this->wrapped(); } + }; + +public: + /// This class provides iteration over the held operations of a block for a + /// specific operation type. + template + class op_iterator : public llvm::mapped_iterator, + OpT (*)(Operation &)> { + static OpT unwrap(Operation &op) { return cast(op); } + + public: + using reference = OpT; + + /// Initializes the iterator to the specified filter iterator. + op_iterator(op_filter_iterator it) + : llvm::mapped_iterator, OpT (*)(Operation &)>( + it, &unwrap) {} + + /// Allow implicit conversion to the underlying block iterator. + operator Block::iterator() const { return this->wrapped(); } + }; + + /// Return an iterator range over the operations within this block that are of + /// 'OpT'. + template iterator_range> getOps() { + auto endIt = end(); + return {op_filter_iterator(begin(), endIt), + op_filter_iterator(endIt, endIt)}; + } + template op_iterator op_begin() { + return op_filter_iterator(begin(), end()); + } + template op_iterator op_end() { + return op_filter_iterator(end(), end()); + } + + /// Return an iterator range over the operation within this block excluding + /// the terminator operation at the end. + iterator_range without_terminator() { + if (begin() == end()) + return {begin(), end()}; + auto endIt = --end(); + return {begin(), endIt}; + } + + //===--------------------------------------------------------------------===// + // Terminator management + //===--------------------------------------------------------------------===// + + /// Get the terminator operation of this block. 
This function asserts that + /// the block has a valid terminator operation. + Operation *getTerminator(); + + //===--------------------------------------------------------------------===// + // Predecessors and successors. + //===--------------------------------------------------------------------===// + + // Predecessor iteration. + using pred_iterator = PredecessorIterator; + pred_iterator pred_begin() { + return pred_iterator((BlockOperand *)getFirstUse()); + } + pred_iterator pred_end() { return pred_iterator(nullptr); } + iterator_range getPredecessors() { + return {pred_begin(), pred_end()}; + } + + /// Return true if this block has no predecessors. + bool hasNoPredecessors(); + + /// If this block has exactly one predecessor, return it. Otherwise, return + /// null. + /// + /// Note that if a block has duplicate predecessors from a single block (e.g. + /// if you have a conditional branch with the same block as the true/false + /// destinations) is not considered to be a single predecessor. + Block *getSinglePredecessor(); + + // Indexed successor access. + unsigned getNumSuccessors(); + Block *getSuccessor(unsigned i); + + // Successor iteration. + using succ_iterator = SuccessorRange::iterator; + succ_iterator succ_begin() { return getSuccessors().begin(); } + succ_iterator succ_end() { return getSuccessors().end(); } + SuccessorRange getSuccessors() { return SuccessorRange(this); } + + //===--------------------------------------------------------------------===// + // Operation Walkers + //===--------------------------------------------------------------------===// + + /// Walk the operations in this block in postorder, calling the callback for + /// each operation. + /// See Operation::walk for more details. + template > + RetT walk(FnT &&callback) { + return walk(begin(), end(), std::forward(callback)); + } + + /// Walk the operations in the specified [begin, end) range of this block in + /// postorder, calling the callback for each operation. 
This method is invoked + /// for void return callbacks. + /// See Operation::walk for more details. + template > + typename std::enable_if::value, RetT>::type + walk(Block::iterator begin, Block::iterator end, FnT &&callback) { + for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) + detail::walkOperations(&op, callback); + } + + /// Walk the operations in the specified [begin, end) range of this block in + /// postorder, calling the callback for each operation. This method is invoked + /// for interruptible callbacks. + /// See Operation::walk for more details. + template > + typename std::enable_if::value, RetT>::type + walk(Block::iterator begin, Block::iterator end, FnT &&callback) { + for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) + if (detail::walkOperations(&op, callback).wasInterrupted()) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + + /// Split the block into two blocks before the specified operation or + /// iterator. + /// + /// Note that all operations BEFORE the specified iterator stay as part of + /// the original basic block, and the rest of the operations in the original + /// block are moved to the new block, including the old terminator. The + /// original block is left without a terminator. + /// + /// The newly formed Block is returned, and the specified iterator is + /// invalidated. + Block *splitBlock(iterator splitBefore); + Block *splitBlock(Operation *splitBeforeOp) { + return splitBlock(iterator(splitBeforeOp)); + } + + /// Returns pointer to member of operation list. + static OpListType Block::*getSublistAccess(Operation *) { + return &Block::operations; + } + + void print(raw_ostream &os); + void dump(); + + /// Print out the name of the block without printing its body. 
+ /// NOTE: The printType argument is ignored. We keep it for compatibility + /// with LLVM dominator machinery that expects it to exist. + void printAsOperand(raw_ostream &os, bool printType = true); + +private: + /// Pair of the parent object that owns this block and a bit that signifies if + /// the operations within this block have a valid ordering. + llvm::PointerIntPair parentValidOpOrderPair; + + /// This is the list of operations in the block. + OpListType operations; + + /// This is the list of arguments to the block. + std::vector arguments; + + Block(Block &) = delete; + void operator=(Block &) = delete; + + friend struct llvm::ilist_traits; +}; +} // end namespace mlir + +#endif // MLIR_IR_BLOCK_H diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h new file mode 100644 index 0000000000000000000000000000000000000000..b7ad36072bd1d6aa488e20f008ff6a600f5f8e0f --- /dev/null +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -0,0 +1,88 @@ +//===- BlockAndValueMapping.h -----------------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a utility class for maintaining a mapping for multiple +// value types. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BLOCKANDVALUEMAPPING_H +#define MLIR_IR_BLOCKANDVALUEMAPPING_H + +#include "mlir/IR/Block.h" + +namespace mlir { +// This is a utility class for mapping one set of values to another. New +// mappings can be inserted via 'map'. Existing mappings can be +// found via the 'lookup*' functions. There are two variants that differ only in +// return value when an existing is not found for the provided key. 
+// 'lookupOrNull' returns nullptr where as 'lookupOrDefault' will return the +// lookup key. +class BlockAndValueMapping { +public: + /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping, + /// it is overwritten. + void map(Block *from, Block *to) { valueMap[from] = to; } + void map(Value from, Value to) { + valueMap[from.getAsOpaquePointer()] = to.getAsOpaquePointer(); + } + + /// Erases a mapping for 'from'. + void erase(Block *from) { valueMap.erase(from); } + void erase(Value from) { valueMap.erase(from.getAsOpaquePointer()); } + + /// Checks to see if a mapping for 'from' exists. + bool contains(Block *from) const { return valueMap.count(from); } + bool contains(Value from) const { + return valueMap.count(from.getAsOpaquePointer()); + } + + /// Lookup a mapped value within the map. If a mapping for the provided value + /// does not exist then return nullptr. + Block *lookupOrNull(Block *from) const { + return lookupOrValue(from, (Block *)nullptr); + } + Value lookupOrNull(Value from) const { return lookupOrValue(from, Value()); } + + /// Lookup a mapped value within the map. If a mapping for the provided value + /// does not exist then return the provided value. + Block *lookupOrDefault(Block *from) const { + return lookupOrValue(from, from); + } + Value lookupOrDefault(Value from) const { return lookupOrValue(from, from); } + + /// Lookup a mapped value within the map. This asserts the provided value + /// exists within the map. + template T lookup(T from) const { + auto result = lookupOrNull(from); + assert(result && "expected 'from' to be contained within the map"); + return result; + } + + /// Clears all mappings held by the mapper. + void clear() { valueMap.clear(); } + +private: + /// Utility lookupOrValue that looks up an existing key or returns the + /// provided value. + Block *lookupOrValue(Block *from, Block *value) const { + auto it = valueMap.find(from); + return it != valueMap.end() ? 
reinterpret_cast(it->second) : value; + } + Value lookupOrValue(Value from, Value value) const { + auto it = valueMap.find(from.getAsOpaquePointer()); + return it != valueMap.end() ? Value::getFromOpaquePointer(it->second) + : value; + } + + DenseMap valueMap; +}; + +} // end namespace mlir + +#endif // MLIR_IR_BLOCKANDVALUEMAPPING_H diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h new file mode 100644 index 0000000000000000000000000000000000000000..bc6a8245c45c3b2a3f6d7a53f178bcb59f8e17c7 --- /dev/null +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -0,0 +1,144 @@ +//===- BlockSupport.h -------------------------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a number of support types for the Block class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BLOCK_SUPPORT_H +#define MLIR_IR_BLOCK_SUPPORT_H + +#include "mlir/IR/Value.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/ilist.h" +#include "llvm/ADT/ilist_node.h" + +namespace mlir { +class Block; + +using BlockOperand = IROperandImpl; + +//===----------------------------------------------------------------------===// +// Predecessors +//===----------------------------------------------------------------------===// + +/// Implement a predecessor iterator for blocks. This works by walking the use +/// lists of the blocks. The entries on this list are the BlockOperands that +/// are embedded into terminator operations. From the operand, we can get the +/// terminator that contains it, and its parent block is the predecessor. 
+class PredecessorIterator final + : public llvm::mapped_iterator, + Block *(*)(BlockOperand &)> { + static Block *unwrap(BlockOperand &value); + +public: + using reference = Block *; + + /// Initializes the operand type iterator to the specified operand iterator. + PredecessorIterator(ValueUseIterator it) + : llvm::mapped_iterator, + Block *(*)(BlockOperand &)>(it, &unwrap) {} + explicit PredecessorIterator(BlockOperand *operand) + : PredecessorIterator(ValueUseIterator(operand)) {} + + /// Get the successor number in the predecessor terminator. + unsigned getSuccessorIndex() const; +}; + +//===----------------------------------------------------------------------===// +// Successors +//===----------------------------------------------------------------------===// + +/// This class implements the successor iterators for Block. +class SuccessorRange final + : public detail::indexed_accessor_range_base { +public: + using RangeBaseT::RangeBaseT; + SuccessorRange(Block *block); + SuccessorRange(Operation *term); + +private: + /// See `detail::indexed_accessor_range_base` for details. + static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) { + return object + index; + } + /// See `detail::indexed_accessor_range_base` for details. + static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) { + return object[index].get(); + } + + /// Allow access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; +}; + +} // end namespace mlir + +namespace llvm { + +//===----------------------------------------------------------------------===// +// ilist_traits for Operation +//===----------------------------------------------------------------------===// + +namespace ilist_detail { +// Explicitly define the node access for the operation list so that we can +// break the dependence on the Operation class in this header. This allows for +// operations to have trailing Regions without a circular include +// dependence. 
+template <> +struct SpecificNodeAccess< + typename compute_node_options<::mlir::Operation>::type> : NodeAccess { +protected: + using OptionsT = typename compute_node_options::type; + using pointer = typename OptionsT::pointer; + using const_pointer = typename OptionsT::const_pointer; + using node_type = ilist_node_impl; + + static node_type *getNodePtr(pointer N); + static const node_type *getNodePtr(const_pointer N); + + static pointer getValuePtr(node_type *N); + static const_pointer getValuePtr(const node_type *N); +}; +} // end namespace ilist_detail + +template <> struct ilist_traits<::mlir::Operation> { + using Operation = ::mlir::Operation; + using op_iterator = simple_ilist::iterator; + + static void deleteNode(Operation *op); + void addNodeToList(Operation *op); + void removeNodeFromList(Operation *op); + void transferNodesFromList(ilist_traits &otherList, + op_iterator first, op_iterator last); + +private: + mlir::Block *getContainingBlock(); +}; + +//===----------------------------------------------------------------------===// +// ilist_traits for Block +//===----------------------------------------------------------------------===// + +template <> +struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> { + using Block = ::mlir::Block; + using block_iterator = simple_ilist<::mlir::Block>::iterator; + + void addNodeToList(Block *block); + void removeNodeFromList(Block *block); + void transferNodesFromList(ilist_traits &otherList, + block_iterator first, block_iterator last); + +private: + mlir::Region *getParentRegion(); +}; + +} // end namespace llvm + +#endif // MLIR_IR_BLOCK_SUPPORT_H diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h new file mode 100644 index 0000000000000000000000000000000000000000..2db44cbfa2e71f91a56ce92727d0472bb0e6dbbd --- /dev/null +++ b/mlir/include/mlir/IR/Builders.h @@ -0,0 +1,381 @@ +//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===// +// +// 
Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BUILDERS_H +#define MLIR_IR_BUILDERS_H + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +class AffineExpr; +class BlockAndValueMapping; +class ModuleOp; +class UnknownLoc; +class FileLineColLoc; +class Type; +class PrimitiveType; +class IntegerType; +class FunctionType; +class MemRefType; +class VectorType; +class RankedTensorType; +class UnrankedTensorType; +class TupleType; +class NoneType; +class BoolAttr; +class IntegerAttr; +class FloatAttr; +class StringAttr; +class TypeAttr; +class ArrayAttr; +class SymbolRefAttr; +class ElementsAttr; +class DenseElementsAttr; +class DenseIntElementsAttr; +class AffineMapAttr; +class AffineMap; +class UnitAttr; + +/// This class is a general helper class for creating context-global objects +/// like types, attributes, and affine expressions. +class Builder { +public: + explicit Builder(MLIRContext *context) : context(context) {} + explicit Builder(ModuleOp module); + + MLIRContext *getContext() const { return context; } + + Identifier getIdentifier(StringRef str); + + // Locations. + Location getUnknownLoc(); + Location getFileLineColLoc(Identifier filename, unsigned line, + unsigned column); + Location getFusedLoc(ArrayRef locs, + Attribute metadata = Attribute()); + + // Types. + FloatType getBF16Type(); + FloatType getF16Type(); + FloatType getF32Type(); + FloatType getF64Type(); + + IndexType getIndexType(); + + IntegerType getI1Type(); + IntegerType getIntegerType(unsigned width); + FunctionType getFunctionType(ArrayRef inputs, ArrayRef results); + TupleType getTupleType(ArrayRef elementTypes); + NoneType getNoneType(); + + /// Get or construct an instance of the type 'ty' with provided arguments. 
+ template Ty getType(Args... args) { + return Ty::get(context, args...); + } + + // Attributes. + NamedAttribute getNamedAttr(StringRef name, Attribute val); + + UnitAttr getUnitAttr(); + BoolAttr getBoolAttr(bool value); + DictionaryAttr getDictionaryAttr(ArrayRef value); + IntegerAttr getIntegerAttr(Type type, int64_t value); + IntegerAttr getIntegerAttr(Type type, const APInt &value); + FloatAttr getFloatAttr(Type type, double value); + FloatAttr getFloatAttr(Type type, const APFloat &value); + StringAttr getStringAttr(StringRef bytes); + ArrayAttr getArrayAttr(ArrayRef value); + FlatSymbolRefAttr getSymbolRefAttr(Operation *value); + FlatSymbolRefAttr getSymbolRefAttr(StringRef value); + SymbolRefAttr getSymbolRefAttr(StringRef value, + ArrayRef nestedReferences); + + // Returns a 0-valued attribute of the given `type`. This function only + // supports boolean, integer, and 16-/32-/64-bit float types, and vector or + // ranked tensor of them. Returns null attribute otherwise. + Attribute getZeroAttr(Type type); + + // Convenience methods for fixed types. + FloatAttr getF16FloatAttr(float value); + FloatAttr getF32FloatAttr(float value); + FloatAttr getF64FloatAttr(double value); + + IntegerAttr getI8IntegerAttr(int8_t value); + IntegerAttr getI16IntegerAttr(int16_t value); + IntegerAttr getI32IntegerAttr(int32_t value); + IntegerAttr getI64IntegerAttr(int64_t value); + + DenseIntElementsAttr getI32VectorAttr(ArrayRef values); + + ArrayAttr getAffineMapArrayAttr(ArrayRef values); + ArrayAttr getI32ArrayAttr(ArrayRef values); + ArrayAttr getI64ArrayAttr(ArrayRef values); + ArrayAttr getIndexArrayAttr(ArrayRef values); + ArrayAttr getF32ArrayAttr(ArrayRef values); + ArrayAttr getF64ArrayAttr(ArrayRef values); + ArrayAttr getStrArrayAttr(ArrayRef values); + + // Affine expressions and affine maps. 
+ AffineExpr getAffineDimExpr(unsigned position); + AffineExpr getAffineSymbolExpr(unsigned position); + AffineExpr getAffineConstantExpr(int64_t constant); + + // Special cases of affine maps and integer sets + /// Returns a zero result affine map with no dimensions or symbols: () -> (). + AffineMap getEmptyAffineMap(); + /// Returns a single constant result affine map with 0 dimensions and 0 + /// symbols. One constant result: () -> (val). + AffineMap getConstantAffineMap(int64_t val); + // One dimension id identity map: (i) -> (i). + AffineMap getDimIdentityMap(); + // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2). + AffineMap getMultiDimIdentityMap(unsigned rank); + // One symbol identity map: ()[s] -> (s). + AffineMap getSymbolIdentityMap(); + + /// Returns a map that shifts its (single) input dimension by 'shift'. + /// (d0) -> (d0 + shift) + AffineMap getSingleDimShiftAffineMap(int64_t shift); + + /// Returns an affine map that is a translation (shift) of all result + /// expressions in 'map' by 'shift'. + /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2 + /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2) + AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); + +protected: + MLIRContext *context; +}; + +/// This class helps build Operations. Operations that are created are +/// automatically inserted at an insertion point. The builder is copyable. +class OpBuilder : public Builder { +public: + /// Create a builder with the given context. + explicit OpBuilder(MLIRContext *ctx) : Builder(ctx) {} + + /// Create a builder and set the insertion point to the start of the region. 
+ explicit OpBuilder(Region *region) : Builder(region->getContext()) { + if (!region->empty()) + setInsertionPoint(®ion->front(), region->front().begin()); + } + explicit OpBuilder(Region ®ion) : OpBuilder(®ion) {} + + virtual ~OpBuilder(); + + /// Create a builder and set insertion point to the given operation, which + /// will cause subsequent insertions to go right before it. + explicit OpBuilder(Operation *op) : Builder(op->getContext()) { + setInsertionPoint(op); + } + + explicit OpBuilder(Block *block) : OpBuilder(block, block->end()) {} + + OpBuilder(Block *block, Block::iterator insertPoint) + : OpBuilder(block->getParent()) { + setInsertionPoint(block, insertPoint); + } + + /// This class represents a saved insertion point. + class InsertPoint { + public: + /// Creates a new insertion point which doesn't point to anything. + InsertPoint() = default; + + /// Creates a new insertion point at the given location. + InsertPoint(Block *insertBlock, Block::iterator insertPt) + : block(insertBlock), point(insertPt) {} + + /// Returns true if this insert point is set. + bool isSet() const { return (block != nullptr); } + + Block *getBlock() const { return block; } + Block::iterator getPoint() const { return point; } + + private: + Block *block = nullptr; + Block::iterator point; + }; + + /// RAII guard to reset the insertion point of the builder when destroyed. + class InsertionGuard { + public: + InsertionGuard(OpBuilder &builder) + : builder(builder), ip(builder.saveInsertionPoint()) {} + ~InsertionGuard() { builder.restoreInsertionPoint(ip); } + + private: + OpBuilder &builder; + OpBuilder::InsertPoint ip; + }; + + /// Reset the insertion point to no location. Creating an operation without a + /// set insertion point is an error, but this can still be useful when the + /// current insertion point a builder refers to is being removed. 
+ void clearInsertionPoint() { + this->block = nullptr; + insertPoint = Block::iterator(); + } + + /// Return a saved insertion point. + InsertPoint saveInsertionPoint() const { + return InsertPoint(getInsertionBlock(), getInsertionPoint()); + } + + /// Restore the insert point to a previously saved point. + void restoreInsertionPoint(InsertPoint ip) { + if (ip.isSet()) + setInsertionPoint(ip.getBlock(), ip.getPoint()); + else + clearInsertionPoint(); + } + + /// Set the insertion point to the specified location. + void setInsertionPoint(Block *block, Block::iterator insertPoint) { + // TODO: check that insertPoint is in this rather than some other block. + this->block = block; + this->insertPoint = insertPoint; + } + + /// Sets the insertion point to the specified operation, which will cause + /// subsequent insertions to go right before it. + void setInsertionPoint(Operation *op) { + setInsertionPoint(op->getBlock(), Block::iterator(op)); + } + + /// Sets the insertion point to the node after the specified operation, which + /// will cause subsequent insertions to go right after it. + void setInsertionPointAfter(Operation *op) { + setInsertionPoint(op->getBlock(), ++Block::iterator(op)); + } + + /// Sets the insertion point to the start of the specified block. + void setInsertionPointToStart(Block *block) { + setInsertionPoint(block, block->begin()); + } + + /// Sets the insertion point to the end of the specified block. + void setInsertionPointToEnd(Block *block) { + setInsertionPoint(block, block->end()); + } + + /// Return the block the current insertion point belongs to. Note that the + /// the insertion point is not necessarily the end of the block. + Block *getInsertionBlock() const { return block; } + + /// Returns the current insertion point of the builder. + Block::iterator getInsertionPoint() const { return insertPoint; } + + /// Insert the given operation at the current insertion point and return it. 
+ virtual Operation *insert(Operation *op); + + /// Add new block and set the insertion point to the end of it. The block is + /// inserted at the provided insertion point of 'parent'. + Block *createBlock(Region *parent, Region::iterator insertPt = {}); + + /// Add new block and set the insertion point to the end of it. The block is + /// placed before 'insertBefore'. + Block *createBlock(Block *insertBefore); + + /// Returns the current block of the builder. + Block *getBlock() const { return block; } + + /// Creates an operation given the fields represented as an OperationState. + Operation *createOperation(const OperationState &state); + + /// Create an operation of specific op type at the current insertion point. + template + OpTy create(Location location, Args &&... args) { + OperationState state(location, OpTy::getOperationName()); + OpTy::build(this, state, std::forward(args)...); + auto *op = createOperation(state); + auto result = dyn_cast(op); + assert(result && "Builder didn't return the right type"); + return result; + } + + /// Create an operation of specific op type at the current insertion point, + /// and immediately try to fold it. This functions populates 'results' with + /// the results after folding the operation. + template + void createOrFold(SmallVectorImpl &results, Location location, + Args &&... args) { + // Create the operation without using 'createOperation' as we don't want to + // insert it yet. + OperationState state(location, OpTy::getOperationName()); + OpTy::build(this, state, std::forward(args)...); + Operation *op = Operation::create(state); + + // Fold the operation. If successful destroy it, otherwise insert it. + if (succeeded(tryFold(op, results))) + op->destroy(); + else + insert(op); + } + + /// Overload to create or fold a single result operation. + template + typename std::enable_if(), + Value>::type + createOrFold(Location location, Args &&... 
args) { + SmallVector results; + createOrFold(results, location, std::forward(args)...); + return results.front(); + } + + /// Overload to create or fold a zero result operation. + template + typename std::enable_if(), + OpTy>::type + createOrFold(Location location, Args &&... args) { + auto op = create(location, std::forward(args)...); + SmallVector unused; + tryFold(op.getOperation(), unused); + + // Folding cannot remove a zero-result operation, so for convenience we + // continue to return it. + return op; + } + + /// Attempts to fold the given operation and places new results within + /// 'results'. Returns success if the operation was folded, failure otherwise. + /// Note: This function does not erase the operation on a successful fold. + LogicalResult tryFold(Operation *op, SmallVectorImpl &results); + + /// Creates a deep copy of the specified operation, remapping any operands + /// that use values outside of the operation using the map that is provided + /// ( leaving them alone if no entry is present). Replaces references to + /// cloned sub-operations to the corresponding operation that is copied, + /// and adds those mappings to the map. + Operation *clone(Operation &op, BlockAndValueMapping &mapper) { + return insert(op.clone(mapper)); + } + Operation *clone(Operation &op) { return insert(op.clone()); } + + /// Creates a deep copy of this operation but keep the operation regions + /// empty. Operands are remapped using `mapper` (if present), and `mapper` is + /// updated to contain the results. 
+ Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) { + return insert(op.cloneWithoutRegions(mapper)); + } + Operation *cloneWithoutRegions(Operation &op) { + return insert(op.cloneWithoutRegions()); + } + template OpT cloneWithoutRegions(OpT op) { + return cast(cloneWithoutRegions(*op.getOperation())); + } + +private: + Block *block = nullptr; + Block::iterator insertPoint; +}; + +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..555b16fd29d0386fbfde1187a4f229c04fea2a6d --- /dev/null +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td) +mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIROpAsmInterfacesIncGen) diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h new file mode 100644 index 0000000000000000000000000000000000000000..e3d0f8382083332745ab6aa54caa84e904922448 --- /dev/null +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -0,0 +1,649 @@ +//===- Diagnostics.h - MLIR Diagnostics -------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines utilities for emitting diagnostics. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIAGNOSTICS_H +#define MLIR_IR_DIAGNOSTICS_H + +#include "mlir/IR/Location.h" +#include "mlir/Support/STLExtras.h" +#include + +namespace llvm { +class MemoryBuffer; +class SMLoc; +class SourceMgr; +} // end namespace llvm + +namespace mlir { +class DiagnosticEngine; +class Identifier; +struct LogicalResult; +class MLIRContext; +class Operation; +class OperationName; +class Type; + +namespace detail { +struct DiagnosticEngineImpl; +} // end namespace detail + +/// Defines the different supported severity of a diagnostic. +enum class DiagnosticSeverity { + Note, + Warning, + Error, + Remark, +}; + +//===----------------------------------------------------------------------===// +// DiagnosticArgument +//===----------------------------------------------------------------------===// + +/// A variant type that holds a single argument for a diagnostic. +class DiagnosticArgument { +public: + /// Enum that represents the different kinds of diagnostic arguments + /// supported. + enum class DiagnosticArgumentKind { + Attribute, + Double, + Integer, + Operation, + String, + Type, + Unsigned, + }; + + /// Outputs this argument to a stream. + void print(raw_ostream &os) const; + + /// Returns the kind of this argument. + DiagnosticArgumentKind getKind() const { return kind; } + + /// Returns this argument as an Attribute. + Attribute getAsAttribute() const; + + /// Returns this argument as a double. + double getAsDouble() const { + assert(getKind() == DiagnosticArgumentKind::Double); + return doubleVal; + } + + /// Returns this argument as a signed integer. + int64_t getAsInteger() const { + assert(getKind() == DiagnosticArgumentKind::Integer); + return static_cast(opaqueVal); + } + + /// Returns this argument as an operation. 
+ Operation &getAsOperation() const { + assert(getKind() == DiagnosticArgumentKind::Operation); + return *reinterpret_cast(opaqueVal); + } + + /// Returns this argument as a string. + StringRef getAsString() const { + assert(getKind() == DiagnosticArgumentKind::String); + return stringVal; + } + + /// Returns this argument as a Type. + Type getAsType() const; + + /// Returns this argument as an unsigned integer. + uint64_t getAsUnsigned() const { + assert(getKind() == DiagnosticArgumentKind::Unsigned); + return static_cast(opaqueVal); + } + +private: + friend class Diagnostic; + + // Construct from an Attribute. + explicit DiagnosticArgument(Attribute attr); + + // Construct from a floating point number. + explicit DiagnosticArgument(double val) + : kind(DiagnosticArgumentKind::Double), doubleVal(val) {} + explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {} + + // Construct from a signed integer. + template + explicit DiagnosticArgument( + T val, typename std::enable_if::value && + std::numeric_limits::is_integer && + sizeof(T) <= sizeof(int64_t)>::type * = 0) + : kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {} + + // Construct from an unsigned integer. + template + explicit DiagnosticArgument( + T val, typename std::enable_if::value && + std::numeric_limits::is_integer && + sizeof(T) <= sizeof(uint64_t)>::type * = 0) + : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {} + + // Construct from an operation reference. + explicit DiagnosticArgument(Operation &val) : DiagnosticArgument(&val) {} + explicit DiagnosticArgument(Operation *val) + : kind(DiagnosticArgumentKind::Operation), + opaqueVal(reinterpret_cast(val)) { + assert(val && "expected valid operation"); + } + + // Construct from a string reference. + explicit DiagnosticArgument(StringRef val) + : kind(DiagnosticArgumentKind::String), stringVal(val) {} + + // Construct from a Type. 
+ explicit DiagnosticArgument(Type val); + + /// The kind of this argument. + DiagnosticArgumentKind kind; + + /// The value of this argument. + union { + double doubleVal; + intptr_t opaqueVal; + StringRef stringVal; + }; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const DiagnosticArgument &arg) { + arg.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// Diagnostic +//===----------------------------------------------------------------------===// + +/// This class contains all of the information necessary to report a diagnostic +/// to the DiagnosticEngine. It should generally not be constructed directly, +/// and instead used transitively via InFlightDiagnostic. +class Diagnostic { + using NoteVector = std::vector>; + + /// This class implements a wrapper iterator around NoteVector::iterator to + /// implicitly dereference the unique_ptr. + template + class NoteIteratorImpl + : public llvm::mapped_iterator { + static ResultTy &unwrap(NotePtrTy note) { return *note; } + + public: + NoteIteratorImpl(IteratorTy it) + : llvm::mapped_iterator(it, + &unwrap) {} + }; + +public: + Diagnostic(Location loc, DiagnosticSeverity severity) + : loc(loc), severity(severity) {} + Diagnostic(Diagnostic &&) = default; + Diagnostic &operator=(Diagnostic &&) = default; + + /// Returns the severity of this diagnostic. + DiagnosticSeverity getSeverity() const { return severity; } + + /// Returns the source location for this diagnostic. + Location getLocation() const { return loc; } + + /// Returns the current list of diagnostic arguments. + MutableArrayRef getArguments() { return arguments; } + ArrayRef getArguments() const { return arguments; } + + /// Stream operator for inserting new diagnostic arguments. 
+ template + typename std::enable_if::value, + Diagnostic &>::type + operator<<(Arg &&val) { + arguments.push_back(DiagnosticArgument(std::forward(val))); + return *this; + } + + /// Stream in a string literal. + Diagnostic &operator<<(const char *val) { + arguments.push_back(DiagnosticArgument(val)); + return *this; + } + + /// Stream in a Twine argument. + Diagnostic &operator<<(char val); + Diagnostic &operator<<(const Twine &val); + Diagnostic &operator<<(Twine &&val); + + /// Stream in an Identifier. + Diagnostic &operator<<(Identifier val); + + /// Stream in an OperationName. + Diagnostic &operator<<(OperationName val); + + /// Stream in a range. + template Diagnostic &operator<<(iterator_range range) { + return appendRange(range); + } + template Diagnostic &operator<<(ArrayRef range) { + return appendRange(range); + } + + /// Append a range to the diagnostic. The default delimiter between elements + /// is ','. + template class Container> + Diagnostic &appendRange(const Container &c, const char *delim = ", ") { + interleave( + c, [&](const detail::ValueOfRange> &a) { *this << a; }, + [&]() { *this << delim; }); + return *this; + } + + /// Append arguments to the diagnostic. + template + Diagnostic &append(Arg1 &&arg1, Arg2 &&arg2, Args &&... args) { + append(std::forward(arg1)); + return append(std::forward(arg2), std::forward(args)...); + } + /// Append one argument to the diagnostic. + template Diagnostic &append(Arg &&arg) { + *this << std::forward(arg); + return *this; + } + + /// Outputs this diagnostic to a stream. + void print(raw_ostream &os) const; + + /// Converts the diagnostic to a string. + std::string str() const; + + /// Attaches a note to this diagnostic. A new location may be optionally + /// provided, if not, then the location defaults to the one specified for this + /// diagnostic. Notes may not be attached to other notes. 
+ Diagnostic &attachNote(Optional noteLoc = llvm::None); + + using note_iterator = NoteIteratorImpl; + using const_note_iterator = NoteIteratorImpl; + + /// Returns the notes held by this diagnostic. + iterator_range getNotes() { + return {notes.begin(), notes.end()}; + } + iterator_range getNotes() const { + return {notes.begin(), notes.end()}; + } + + /// Allow a diagnostic to be converted to 'failure'. + operator LogicalResult() const; + +private: + Diagnostic(const Diagnostic &rhs) = delete; + Diagnostic &operator=(const Diagnostic &rhs) = delete; + + /// The source location. + Location loc; + + /// The severity of this diagnostic. + DiagnosticSeverity severity; + + /// The current list of arguments. + SmallVector arguments; + + /// A list of string values used as arguments. This is used to guarantee the + /// liveness of non-constant strings used in diagnostics. + std::vector> strings; + + /// A list of attached notes. + NoteVector notes; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const Diagnostic &diag) { + diag.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// InFlightDiagnostic +//===----------------------------------------------------------------------===// + +/// This class represents a diagnostic that is inflight and set to be reported. +/// This allows for last minute modifications of the diagnostic before it is +/// emitted by a DiagnosticEngine. +class InFlightDiagnostic { +public: + InFlightDiagnostic() = default; + InFlightDiagnostic(InFlightDiagnostic &&rhs) + : owner(rhs.owner), impl(std::move(rhs.impl)) { + // Reset the rhs diagnostic. + rhs.impl.reset(); + rhs.abandon(); + } + ~InFlightDiagnostic() { + if (isInFlight()) + report(); + } + + /// Stream operator for new diagnostic arguments. 
+ template InFlightDiagnostic &operator<<(Arg &&arg) & { + return append(std::forward(arg)); + } + template InFlightDiagnostic &&operator<<(Arg &&arg) && { + return std::move(append(std::forward(arg))); + } + + /// Append arguments to the diagnostic. + template InFlightDiagnostic &append(Args &&... args) & { + assert(isActive() && "diagnostic not active"); + if (isInFlight()) + impl->append(std::forward(args)...); + return *this; + } + template InFlightDiagnostic &&append(Args &&... args) && { + return std::move(append(std::forward(args)...)); + } + + /// Attaches a note to this diagnostic. + Diagnostic &attachNote(Optional noteLoc = llvm::None) { + assert(isActive() && "diagnostic not active"); + return impl->attachNote(noteLoc); + } + + /// Reports the diagnostic to the engine. + void report(); + + /// Abandons this diagnostic so that it will no longer be reported. + void abandon(); + + /// Allow an inflight diagnostic to be converted to 'failure', otherwise + /// 'success' if this is an empty diagnostic. + operator LogicalResult() const; + +private: + InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete; + InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete; + InFlightDiagnostic(DiagnosticEngine *owner, Diagnostic &&rhs) + : owner(owner), impl(std::move(rhs)) {} + + /// Returns if the diagnostic is still active, i.e. it has a live diagnostic. + bool isActive() const { return impl.hasValue(); } + + /// Returns if the diagnostic is still in flight to be reported. + bool isInFlight() const { return owner; } + + // Allow access to the constructor. + friend DiagnosticEngine; + + /// The engine that this diagnostic is to report to. + DiagnosticEngine *owner = nullptr; + + /// The raw diagnostic that is inflight to be reported. 
+ Optional impl; +}; + +//===----------------------------------------------------------------------===// +// DiagnosticEngine +//===----------------------------------------------------------------------===// + +/// This class is the main interface for diagnostics. The DiagnosticEngine +/// manages the registration of diagnostic handlers as well as the core API for +/// diagnostic emission. This class should not be constructed directly, but +/// instead interfaced with via an MLIRContext instance. +class DiagnosticEngine { +public: + ~DiagnosticEngine(); + + // Diagnostic handler registration and use. MLIR supports the ability for the + // IR to carry arbitrary metadata about operation location information. If a + // problem is detected by the compiler, it can invoke the emitError / + // emitWarning / emitRemark method on an Operation and have it get reported + // through this interface. + // + // Tools using MLIR are encouraged to register error handlers and define a + // schema for their location information. If they don't, then warnings and + // notes will be dropped and errors will be emitted to errs. + + /// The handler type for MLIR diagnostics. This function takes a diagnostic as + /// input, and returns success if the handler has fully processed this + /// diagnostic. Returns failure otherwise. + using HandlerTy = std::function; + + /// A handle to a specific registered handler object. + using HandlerID = uint64_t; + + /// Register a new handler for diagnostics to the engine. Diagnostics are + /// process by handlers in stack-like order, meaning that the last added + /// handlers will process diagnostics first. This function returns a unique + /// identifier for the registered handler, which can be used to unregister + /// this handler at a later time. + HandlerID registerHandler(const HandlerTy &handler); + + /// Set the diagnostic handler with a function that returns void. 
This is a + /// convenient wrapper for handlers that always completely process the given + /// diagnostic. + template ()( + std::declval()))> + std::enable_if_t::value, HandlerID> + registerHandler(FuncTy &&handler) { + return registerHandler([=](Diagnostic &diag) { + handler(diag); + return success(); + }); + } + + /// Erase the registered diagnostic handler with the given identifier. + void eraseHandler(HandlerID id); + + /// Create a new inflight diagnostic with the given location and severity. + InFlightDiagnostic emit(Location loc, DiagnosticSeverity severity) { + assert(severity != DiagnosticSeverity::Note && + "notes should not be emitted directly"); + return InFlightDiagnostic(this, Diagnostic(loc, severity)); + } + + /// Emit a diagnostic using the registered issue handler if present, or with + /// the default behavior if not. + void emit(Diagnostic diag); + +private: + friend class MLIRContextImpl; + DiagnosticEngine(); + + /// The internal implementation of the DiagnosticEngine. + std::unique_ptr impl; +}; + +/// Utility method to emit an error message using this location. +InFlightDiagnostic emitError(Location loc); +InFlightDiagnostic emitError(Location loc, const Twine &message); + +/// Utility method to emit a warning message using this location. +InFlightDiagnostic emitWarning(Location loc); +InFlightDiagnostic emitWarning(Location loc, const Twine &message); + +/// Utility method to emit a remark message using this location. +InFlightDiagnostic emitRemark(Location loc); +InFlightDiagnostic emitRemark(Location loc, const Twine &message); + +/// Overloads of the above emission functions that take an optionally null +/// location. If the location is null, no diagnostic is emitted and a failure is +/// returned. Given that the provided location may be null, these methods take +/// the diagnostic arguments directly instead of relying on the returned +/// InFlightDiagnostic. +template +LogicalResult emitOptionalError(Optional loc, Args &&... 
args) { + if (loc) + return emitError(*loc).append(std::forward(args)...); + return failure(); +} +template +LogicalResult emitOptionalWarning(Optional loc, Args &&... args) { + if (loc) + return emitWarning(*loc).append(std::forward(args)...); + return failure(); +} +template +LogicalResult emitOptionalRemark(Optional loc, Args &&... args) { + if (loc) + return emitRemark(*loc).append(std::forward(args)...); + return failure(); +} + +//===----------------------------------------------------------------------===// +// ScopedDiagnosticHandler +//===----------------------------------------------------------------------===// + +/// This diagnostic handler is a simple RAII class that registers and erases a +/// diagnostic handler on a given context. This class can be either be used +/// directly, or in conjunction with a derived diagnostic handler. +class ScopedDiagnosticHandler { +public: + explicit ScopedDiagnosticHandler(MLIRContext *ctx) : handlerID(0), ctx(ctx) {} + template + ScopedDiagnosticHandler(MLIRContext *ctx, FuncTy &&handler) + : handlerID(0), ctx(ctx) { + setHandler(std::forward(handler)); + } + ~ScopedDiagnosticHandler(); + +protected: + /// Set the handler to manage via RAII. + template void setHandler(FuncTy &&handler) { + auto &diagEngine = ctx->getDiagEngine(); + if (handlerID) + diagEngine.eraseHandler(handlerID); + handlerID = diagEngine.registerHandler(std::forward(handler)); + } + +private: + /// The unique id for the scoped handler. + DiagnosticEngine::HandlerID handlerID; + + /// The context to erase the handler from. + MLIRContext *ctx; +}; + +//===----------------------------------------------------------------------===// +// SourceMgrDiagnosticHandler +//===----------------------------------------------------------------------===// + +namespace detail { +struct SourceMgrDiagnosticHandlerImpl; +} // end namespace detail + +/// This class is a utility diagnostic handler for use with llvm::SourceMgr. 
+class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler { +public: + SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx, + raw_ostream &os); + SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx); + ~SourceMgrDiagnosticHandler(); + + /// Emit the given diagnostic information with the held source manager. + void emitDiagnostic(Location loc, Twine message, DiagnosticSeverity kind); + +protected: + /// Emit the given diagnostic with the held source manager. + void emitDiagnostic(Diagnostic &diag); + + /// Get a memory buffer for the given file, or nullptr if no file is + /// available. + const llvm::MemoryBuffer *getBufferForFile(StringRef filename); + + /// The source manager that we are wrapping. + llvm::SourceMgr &mgr; + + /// The output stream to use when printing diagnostics. + raw_ostream &os; + +private: + /// Convert a location into the given memory buffer into an SMLoc. + llvm::SMLoc convertLocToSMLoc(FileLineColLoc loc); + + /// The maximum depth that a call stack will be printed. + /// TODO(riverriddle) This should be a tunable flag. + unsigned callStackLimit = 10; + + std::unique_ptr impl; +}; + +//===----------------------------------------------------------------------===// +// SourceMgrDiagnosticVerifierHandler +//===----------------------------------------------------------------------===// + +namespace detail { +struct SourceMgrDiagnosticVerifierHandlerImpl; +} // end namespace detail + +/// This class is a utility diagnostic handler for use with llvm::SourceMgr that +/// verifies that emitted diagnostics match 'expected-*' lines on the +/// corresponding line of the source file. 
+class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler { +public: + SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx, + raw_ostream &out); + SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx); + ~SourceMgrDiagnosticVerifierHandler(); + + /// Returns the status of the handler and verifies that all expected + /// diagnostics were emitted. This return success if all diagnostics were + /// verified correctly, failure otherwise. + LogicalResult verify(); + +private: + /// Process a single diagnostic. + void process(Diagnostic &diag); + + /// Process a FileLineColLoc diagnostic. + void process(FileLineColLoc loc, StringRef msg, DiagnosticSeverity kind); + + std::unique_ptr impl; +}; + +//===----------------------------------------------------------------------===// +// ParallelDiagnosticHandler +//===----------------------------------------------------------------------===// + +namespace detail { +struct ParallelDiagnosticHandlerImpl; +} // end namespace detail + +/// This class is a utility diagnostic handler for use when multi-threading some +/// part of the compiler where diagnostics may be emitted. This handler ensures +/// a deterministic ordering to the emitted diagnostics that mirrors that of a +/// single-threaded compilation. +class ParallelDiagnosticHandler { +public: + ParallelDiagnosticHandler(MLIRContext *ctx); + ~ParallelDiagnosticHandler(); + + /// Set the order id for the current thread. This is required to be set by + /// each thread that will be emitting diagnostics to this handler. The orderID + /// corresponds to the order in which diagnostics would be emitted when + /// executing synchronously. For example, if we were processing a list + /// of operations [a, b, c] on a single-thread. Diagnostics emitted while + /// processing operation 'a' would be emitted before those for 'b' or 'c'. + /// This corresponds 1-1 with the 'orderID'. 
The thread that is processing 'a' + /// should set the orderID to '0'; the thread processing 'b' should set it to + /// '1'; and so on and so forth. This provides a way for the handler to + /// deterministically order the diagnostics that it receives given the thread + /// that it is receiving on. + void setOrderIDForThread(size_t orderID); + + /// Remove the order id for the current thread. This removes the thread from + /// diagnostics tracking. + void eraseOrderIDForThread(); + +private: + std::unique_ptr impl; +}; +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..d3b4b055bc0c96ba221432e26b787f98f04fe164 --- /dev/null +++ b/mlir/include/mlir/IR/Dialect.h @@ -0,0 +1,315 @@ +//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the 'dialect' abstraction. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIALECT_H +#define MLIR_IR_DIALECT_H + +#include "mlir/IR/OperationSupport.h" + +namespace mlir { +class DialectAsmParser; +class DialectAsmPrinter; +class DialectInterface; +class OpBuilder; +class Type; + +using DialectConstantDecodeHook = + std::function; +using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; +using DialectExtractElementHook = + std::function)>; + +/// Dialects are groups of MLIR operations and behavior associated with the +/// entire group. For example, hooks into other systems for constant folding, +/// default named types for asm printing, etc. 
+/// +/// Instances of the dialect object are global across all MLIRContext's that may +/// be active in the process. +/// +class Dialect { +public: + virtual ~Dialect(); + + /// Utility function that returns if the given string is a valid dialect + /// namespace. + static bool isValidNamespace(StringRef str); + + MLIRContext *getContext() const { return context; } + + StringRef getNamespace() const { return name; } + + /// Returns true if this dialect allows for unregistered operations, i.e. + /// operations prefixed with the dialect namespace but not registered with + /// addOperation. + bool allowsUnknownOperations() const { return unknownOpsAllowed; } + + /// Return true if this dialect allows for unregistered types, i.e., types + /// prefixed with the dialect namespace but not registered with addType. + /// These are represented with OpaqueType. + bool allowsUnknownTypes() const { return unknownTypesAllowed; } + + //===--------------------------------------------------------------------===// + // Constant Hooks + //===--------------------------------------------------------------------===// + + /// Registered fallback constant fold hook for the dialect. Like the constant + /// fold hook of each operation, it attempts to constant fold the operation + /// with the specified constant operand values - the elements in "operands" + /// will correspond directly to the operands of the operation, but may be null + /// if non-constant. If constant folding is successful, this fills in the + /// `results` vector. If not, this returns failure and `results` is + /// unspecified. + DialectConstantFoldHook constantFoldHook = + [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { return failure(); }; + + /// Registered hook to decode opaque constants associated with this + /// dialect. The hook function attempts to decode an opaque constant tensor + /// into a tensor with non-opaque content. 
If decoding is successful, this + /// method returns false and sets 'output' attribute. If not, it returns true + /// and leaves 'output' unspecified. The default hook fails to decode. + DialectConstantDecodeHook decodeHook = + [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; }; + + /// Registered hook to extract an element from an opaque constant associated + /// with this dialect. If element has been successfully extracted, this + /// method returns that element. If not, it returns an empty attribute. + /// The default hook fails to extract an element. + DialectExtractElementHook extractElementHook = + [](const OpaqueElementsAttr input, ArrayRef index) { + return Attribute(); + }; + + /// Registered hook to materialize a single constant operation from a given + /// attribute value with the desired resultant type. This method should use + /// the provided builder to create the operation without changing the + /// insertion position. The generated operation is expected to be constant + /// like, i.e. single result, zero operands, non side-effecting, etc. On + /// success, this hook should return the value generated to represent the + /// constant value. Otherwise, it should return null on failure. + virtual Operation *materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + return nullptr; + } + + //===--------------------------------------------------------------------===// + // Parsing Hooks + //===--------------------------------------------------------------------===// + + /// Parse an attribute registered to this dialect. If 'type' is nonnull, it + /// refers to the expected type of the attribute. + virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; + + /// Print an attribute registered to this dialect. Note: The type of the + /// attribute need not be printed by this method as it is always printed by + /// the caller. 
+ virtual void printAttribute(Attribute, DialectAsmPrinter &) const { + llvm_unreachable("dialect has no registered attribute printing hook"); + } + + /// Parse a type registered to this dialect. + virtual Type parseType(DialectAsmParser &parser) const; + + /// Print a type registered to this dialect. + virtual void printType(Type, DialectAsmPrinter &) const { + llvm_unreachable("dialect has no registered type printing hook"); + } + + //===--------------------------------------------------------------------===// + // Verification Hooks + //===--------------------------------------------------------------------===// + + /// Verify an attribute from this dialect on the argument at 'argIndex' for + /// the region at 'regionIndex' on the given operation. Returns failure if + /// the verification failed, success otherwise. This hook may optionally be + /// invoked from any operation containing a region. + virtual LogicalResult verifyRegionArgAttribute(Operation *, + unsigned regionIndex, + unsigned argIndex, + NamedAttribute); + + /// Verify an attribute from this dialect on the result at 'resultIndex' for + /// the region at 'regionIndex' on the given operation. Returns failure if + /// the verification failed, success otherwise. This hook may optionally be + /// invoked from any operation containing a region. + virtual LogicalResult verifyRegionResultAttribute(Operation *, + unsigned regionIndex, + unsigned resultIndex, + NamedAttribute); + + /// Verify an attribute from this dialect on the given operation. Returns + /// failure if the verification failed, success otherwise. + virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { + return success(); + } + + //===--------------------------------------------------------------------===// + // Interfaces + //===--------------------------------------------------------------------===// + + /// Lookup an interface for the given ID if one is registered, otherwise + /// nullptr. 
+ const DialectInterface *getRegisteredInterface(ClassID *interfaceID) { + auto it = registeredInterfaces.find(interfaceID); + return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; + } + template const InterfaceT *getRegisteredInterface() { + return static_cast( + getRegisteredInterface(InterfaceT::getInterfaceID())); + } + +protected: + /// The constructor takes a unique namespace for this dialect as well as the + /// context to bind to. + /// Note: The namespace must not contain '.' characters. + /// Note: All operations belonging to this dialect must have names starting + /// with the namespace followed by '.'. + /// Example: + /// - "tf" for the TensorFlow ops like "tf.add". + Dialect(StringRef name, MLIRContext *context); + + /// This method is used by derived classes to add their operations to the set. + /// + template void addOperations() { + VariadicOperationAdder::addToSet(*this); + } + + // It would be nice to define this as variadic functions instead of a nested + // variadic type, but we can't do that: function template partial + // specialization is not allowed, and we can't define an overload set because + // we don't have any arguments of the types we are pushing around. + template class VariadicOperationAdder { + public: + static void addToSet(Dialect &dialect) { + dialect.addOperation(AbstractOperation::get(dialect)); + VariadicOperationAdder::addToSet(dialect); + } + }; + + template class VariadicOperationAdder { + public: + static void addToSet(Dialect &dialect) { + dialect.addOperation(AbstractOperation::get(dialect)); + } + }; + + void addOperation(AbstractOperation opInfo); + + /// This method is used by derived classes to add their types to the set. + template void addTypes() { + VariadicSymbolAdder::addToSet(*this); + } + + /// This method is used by derived classes to add their attributes to the set. 
+ template void addAttributes() { + VariadicSymbolAdder::addToSet(*this); + } + + // It would be nice to define this as variadic functions instead of a nested + // variadic type, but we can't do that: function template partial + // specialization is not allowed, and we can't define an overload set + // because we don't have any arguments of the types we are pushing around. + template struct VariadicSymbolAdder { + static void addToSet(Dialect &dialect) { + VariadicSymbolAdder::addToSet(dialect); + VariadicSymbolAdder::addToSet(dialect); + } + }; + + template struct VariadicSymbolAdder { + static void addToSet(Dialect &dialect) { + dialect.addSymbol(First::getClassID()); + } + }; + + /// Enable support for unregistered operations. + void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } + + /// Enable support for unregistered types. + void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } + + /// Register a dialect interface with this dialect instance. + void addInterface(std::unique_ptr interface); + + /// Register a set of dialect interfaces with this dialect instance. + template void addInterfaces() { + addInterfaces(); + addInterfaces(); + } + template void addInterfaces() { + addInterface(std::make_unique(this)); + } + +private: + // Register a symbol(e.g. type) with its given unique class identifier. + void addSymbol(const ClassID *const classID); + + Dialect(const Dialect &) = delete; + void operator=(Dialect &) = delete; + + /// Register this dialect object with the specified context. The context + /// takes ownership of the heap allocated dialect. + void registerDialect(MLIRContext *context); + + /// The namespace of this dialect. + StringRef name; + + /// This is the context that owns this Dialect object. + MLIRContext *context; + + /// Flag that specifies whether this dialect supports unregistered operations, + /// i.e. operations prefixed with the dialect namespace but not registered + /// with addOperation. 
+ bool unknownOpsAllowed = false; + + /// Flag that specifies whether this dialect allows unregistered types, i.e. + /// types prefixed with the dialect namespace but not registered with addType. + /// These types are represented with OpaqueType. + bool unknownTypesAllowed = false; + + /// A collection of registered dialect interfaces. + DenseMap> registeredInterfaces; +}; + +using DialectAllocatorFunction = std::function; + +/// Registers a specific dialect creation function with the system, typically +/// used through the DialectRegistration template. +void registerDialectAllocator(const DialectAllocatorFunction &function); + +/// Registers all dialects with the specified MLIRContext. +void registerAllDialects(MLIRContext *context); + +/// Utility to register a dialect. Client can register their dialect with the +/// global registry by calling registerDialect(); +template void registerDialect() { + registerDialectAllocator([](MLIRContext *ctx) { + // Just allocate the dialect, the context takes ownership of it. + new ConcreteDialect(ctx); + }); +} + +/// DialectRegistration provides a global initializer that registers a Dialect +/// allocation routine. +/// +/// Usage: +/// +/// // At namespace scope. +/// static DialectRegistration Unused; +template struct DialectRegistration { + DialectRegistration() { registerDialect(); } +}; + +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..7e4e1d8335b1645725e416085d5c7d4dac302620 --- /dev/null +++ b/mlir/include/mlir/IR/DialectHooks.h @@ -0,0 +1,73 @@ +//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines abstraction and registration mechanism for dialect hooks. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIALECT_HOOKS_H +#define MLIR_IR_DIALECT_HOOKS_H + +#include "mlir/IR/Dialect.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +using DialectHooksSetter = std::function; + +/// Dialect hooks allow external components to register their functions to +/// be called for specific tasks specialized per dialect, such as decoding +/// of opaque constants. To register concrete dialect hooks, one should +/// define a DialectHooks subclass and use it as a template +/// argument to DialectHooksRegistration. For example, +/// class MyHooks : public DialectHooks {...}; +/// static DialectHooksRegistration hooksReg; +/// The subclass should override DialectHook methods for supported hooks. +class DialectHooks { +public: + // Returns hook to constant fold an operation. + DialectConstantFoldHook getConstantFoldHook() { return nullptr; } + // Returns hook to decode opaque constant tensor. + DialectConstantDecodeHook getDecodeHook() { return nullptr; } + // Returns hook to extract an element of an opaque constant tensor. + DialectExtractElementHook getExtractElementHook() { return nullptr; } +}; + +/// Registers a function that will set hooks in the registered dialects +/// based on information coming from DialectHooksRegistration. +void registerDialectHooksSetter(const DialectHooksSetter &function); + +/// DialectHooksRegistration provides a global initializer that registers +/// a dialect hooks setter routine. +/// Usage: +/// +/// // At namespace scope. 
+/// static DialectHooksRegistration unused; +template struct DialectHooksRegistration { + DialectHooksRegistration(StringRef dialectName) { + registerDialectHooksSetter([dialectName](MLIRContext *ctx) { + Dialect *dialect = ctx->getRegisteredDialect(dialectName); + if (!dialect) { + llvm::errs() << "error: cannot register hooks for unknown dialect '" + << dialectName << "'\n"; + abort(); + } + // Set hooks. + ConcreteHooks hooks; + if (auto h = hooks.getConstantFoldHook()) + dialect->constantFoldHook = h; + if (auto h = hooks.getDecodeHook()) + dialect->decodeHook = h; + if (auto h = hooks.getExtractElementHook()) + dialect->extractElementHook = h; + }); + } +}; + +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h new file mode 100644 index 0000000000000000000000000000000000000000..1eada8f264b14c661fea272e5a32f9468682fed9 --- /dev/null +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -0,0 +1,333 @@ +//===- DialectImplementation.h ----------------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains utilities classes for implementing dialect attributes and +// types. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIALECTIMPLEMENTATION_H +#define MLIR_IR_DIALECTIMPLEMENTATION_H + +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/SMLoc.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { + +class Builder; + +//===----------------------------------------------------------------------===// +// DialectAsmPrinter +//===----------------------------------------------------------------------===// + +/// This is a pure-virtual base class that exposes the asmprinter hooks +/// necessary to implement a custom printAttribute/printType() method on a +/// dialect. +class DialectAsmPrinter { +public: + DialectAsmPrinter() {} + virtual ~DialectAsmPrinter(); + virtual raw_ostream &getStream() const = 0; + + /// Print the given attribute to the stream. + virtual void printAttribute(Attribute attr) = 0; + + /// Print the given floating point value in a stabilized form that can be + /// roundtripped through the IR. This is the companion to the 'parseFloat' + /// hook on the DialectAsmParser. + virtual void printFloat(const APFloat &value) = 0; + + /// Print the given type to the stream. + virtual void printType(Type type) = 0; + +private: + DialectAsmPrinter(const DialectAsmPrinter &) = delete; + void operator=(const DialectAsmPrinter &) = delete; +}; + +// Make the implementations convenient to use. 
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) { + p.printAttribute(attr); + return p; +} + +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, + const APFloat &value) { + p.printFloat(value); + return p; +} +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) { + return p << APFloat(value); +} +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) { + return p << APFloat(value); +} + +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) { + p.printType(type); + return p; +} + +// Support printing anything that isn't convertible to one of the above types, +// even if it isn't exactly one of them. For example, we want to print +// FunctionType with the Type version above, not have it match this. +template ::value && + !std::is_convertible::value && + !std::is_convertible::value && + !llvm::is_one_of::value, + T>::type * = nullptr> +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) { + p.getStream() << other; + return p; +} + +//===----------------------------------------------------------------------===// +// DialectAsmParser +//===----------------------------------------------------------------------===// + +/// The DialectAsmParser has methods for interacting with the asm parser: +/// parsing things from it, emitting errors etc. It has an intentionally +/// high-level API that is designed to reduce/constrain syntax innovation in +/// individual attributes or types. +class DialectAsmParser { +public: + virtual ~DialectAsmParser(); + + /// Emit a diagnostic at the specified location and return failure. + virtual InFlightDiagnostic emitError(llvm::SMLoc loc, + const Twine &message = {}) = 0; + + /// Return a builder which provides useful access to MLIRContext, global + /// objects like types and attributes. + virtual Builder &getBuilder() const = 0; + + /// Get the location of the next token and store it into the argument. This + /// always succeeds. 
+ virtual llvm::SMLoc getCurrentLocation() = 0; + ParseResult getCurrentLocation(llvm::SMLoc *loc) { + *loc = getCurrentLocation(); + return success(); + } + + /// Return the location of the original name token. + virtual llvm::SMLoc getNameLoc() const = 0; + + /// Re-encode the given source location as an MLIR location and return it. + virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; + + /// Returns the full specification of the symbol being parsed. This allows for + /// using a separate parser if necessary. + virtual StringRef getFullSymbolSpec() const = 0; + + // These methods emit an error and return failure or success. This allows + // these to be chained together into a linear sequence of || expressions in + // many cases. + + /// Parse a floating point value from the stream. + virtual ParseResult parseFloat(double &result) = 0; + + /// Parse an integer value from the stream. + template ParseResult parseInteger(IntT &result) { + auto loc = getCurrentLocation(); + OptionalParseResult parseResult = parseOptionalInteger(result); + if (!parseResult.hasValue()) + return emitError(loc, "expected integer value"); + return *parseResult; + } + + /// Parse an optional integer value from the stream. + virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; + + template + OptionalParseResult parseOptionalInteger(IntT &result) { + auto loc = getCurrentLocation(); + + // Parse the unsigned variant. + uint64_t uintResult; + OptionalParseResult parseResult = parseOptionalInteger(uintResult); + if (!parseResult.hasValue() || failed(*parseResult)) + return parseResult; + + // Try to convert to the provided integer type. 
+ result = IntT(uintResult); + if (uint64_t(result) != uintResult) + return emitError(loc, "integer value too large"); + return success(); + } + + //===--------------------------------------------------------------------===// + // Token Parsing + //===--------------------------------------------------------------------===// + + /// Parse a '->' token. + virtual ParseResult parseArrow() = 0; + + /// Parse a '->' token if present + virtual ParseResult parseOptionalArrow() = 0; + + /// Parse a '{' token. + virtual ParseResult parseLBrace() = 0; + + /// Parse a '{' token if present + virtual ParseResult parseOptionalLBrace() = 0; + + /// Parse a `}` token. + virtual ParseResult parseRBrace() = 0; + + /// Parse a `}` token if present + virtual ParseResult parseOptionalRBrace() = 0; + + /// Parse a `:` token. + virtual ParseResult parseColon() = 0; + + /// Parse a `:` token if present. + virtual ParseResult parseOptionalColon() = 0; + + /// Parse a `,` token. + virtual ParseResult parseComma() = 0; + + /// Parse a `,` token if present. + virtual ParseResult parseOptionalComma() = 0; + + /// Parse a `=` token. + virtual ParseResult parseEqual() = 0; + + /// Parse a given keyword. + ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { + auto loc = getCurrentLocation(); + if (parseOptionalKeyword(keyword)) + return emitError(loc, "expected '") << keyword << "'" << msg; + return success(); + } + + /// Parse a keyword into 'keyword'. + ParseResult parseKeyword(StringRef *keyword) { + auto loc = getCurrentLocation(); + if (parseOptionalKeyword(keyword)) + return emitError(loc, "expected valid keyword"); + return success(); + } + + /// Parse the given keyword if present. + virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; + + /// Parse a keyword, if present, into 'keyword'. + virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; + + /// Parse a '<' token. + virtual ParseResult parseLess() = 0; + + /// Parse a `<` token if present. 
+ virtual ParseResult parseOptionalLess() = 0; + + /// Parse a '>' token. + virtual ParseResult parseGreater() = 0; + + /// Parse a `>` token if present. + virtual ParseResult parseOptionalGreater() = 0; + + /// Parse a `(` token. + virtual ParseResult parseLParen() = 0; + + /// Parse a `(` token if present. + virtual ParseResult parseOptionalLParen() = 0; + + /// Parse a `)` token. + virtual ParseResult parseRParen() = 0; + + /// Parse a `)` token if present. + virtual ParseResult parseOptionalRParen() = 0; + + /// Parse a `[` token. + virtual ParseResult parseLSquare() = 0; + + /// Parse a `[` token if present. + virtual ParseResult parseOptionalLSquare() = 0; + + /// Parse a `]` token. + virtual ParseResult parseRSquare() = 0; + + /// Parse a `]` token if present. + virtual ParseResult parseOptionalRSquare() = 0; + + /// Parse a `...` token if present; + virtual ParseResult parseOptionalEllipsis() = 0; + + /// Parse a `?` token. + virtual ParseResult parseOptionalQuestion() = 0; + + /// Parse a `*` token. + virtual ParseResult parseOptionalStar() = 0; + + //===--------------------------------------------------------------------===// + // Attribute Parsing + //===--------------------------------------------------------------------===// + + /// Parse an arbitrary attribute and return it in result. + virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; + + /// Parse an attribute of a specific kind and type. + template + ParseResult parseAttribute(AttrType &result, Type type = {}) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of attribute. + Attribute attr; + if (parseAttribute(attr)) + return failure(); + + // Check for the right kind of attribute. 
+ result = attr.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of attribute specified"); + return success(); + } + + //===--------------------------------------------------------------------===// + // Type Parsing + //===--------------------------------------------------------------------===// + + /// Parse a type. + virtual ParseResult parseType(Type &result) = 0; + + /// Parse a type of a specific kind, e.g. a FunctionType. + template ParseResult parseType(TypeType &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of type. + Type type; + if (parseType(type)) + return failure(); + + // Check for the right kind of attribute. + result = type.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of type specified"); + return success(); + } + + /// Parse a 'x' separated dimension list. This populates the dimension list, + /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on + /// `?` otherwise. + /// + /// dimension-list ::= (dimension `x`)* + /// dimension ::= `?` | integer + /// + /// When `allowDynamic` is not set, this is used to parse: + /// + /// static-dimension-list ::= (integer `x`)* + virtual ParseResult parseDimensionList(SmallVectorImpl &dimensions, + bool allowDynamic = true) = 0; +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..ff1f8fb015abda7c68e31cc04702f03b1e78b1c4 --- /dev/null +++ b/mlir/include/mlir/IR/DialectInterface.h @@ -0,0 +1,181 @@ +//===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIALECTINTERFACE_H +#define MLIR_IR_DIALECTINTERFACE_H + +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { +class Dialect; +class MLIRContext; +class Operation; + +//===----------------------------------------------------------------------===// +// DialectInterface +//===----------------------------------------------------------------------===// +namespace detail { +/// The base class used for all derived interface types. This class provides +/// utilities necessary for registration. +template +class DialectInterfaceBase : public BaseT { +public: + using Base = DialectInterfaceBase; + + /// Get a unique id for the derived interface type. + static ClassID *getInterfaceID() { return ClassID::getID(); } + +protected: + DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {} +}; +} // end namespace detail + +/// This class represents an interface overridden for a single dialect. +class DialectInterface { +public: + virtual ~DialectInterface(); + + /// The base class used for all derived interface types. This class provides + /// utilities necessary for registration. + template + using Base = detail::DialectInterfaceBase; + + /// Return the dialect that this interface represents. + Dialect *getDialect() const { return dialect; } + + /// Return the derived interface id. + ClassID *getID() const { return interfaceID; } + +protected: + DialectInterface(Dialect *dialect, ClassID *id) + : dialect(dialect), interfaceID(id) {} + +private: + /// The dialect that represents this interface. + Dialect *dialect; + + /// The unique identifier for the derived interface type. 
+ ClassID *interfaceID; +}; + +//===----------------------------------------------------------------------===// +// DialectInterfaceCollection +//===----------------------------------------------------------------------===// + +namespace detail { +/// This class is the base class for a collection of instances for a specific +/// interface kind. +class DialectInterfaceCollectionBase { + /// DenseMap info for dialect interfaces that allows lookup by the dialect. + struct InterfaceKeyInfo : public DenseMapInfo { + using DenseMapInfo::isEqual; + + static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } + static unsigned getHashValue(const DialectInterface *key) { + return getHashValue(key->getDialect()); + } + + static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs == rhs->getDialect(); + } + }; + + /// A set of registered dialect interface instances. + using InterfaceSetT = DenseSet; + using InterfaceVectorT = std::vector; + +public: + DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind); + virtual ~DialectInterfaceCollectionBase(); + +protected: + /// Get the interface for the dialect of given operation, or null if one + /// is not registered. + const DialectInterface *getInterfaceFor(Operation *op) const; + + /// Get the interface for the given dialect. + const DialectInterface *getInterfaceFor(Dialect *dialect) const { + auto it = interfaces.find_as(dialect); + return it == interfaces.end() ? nullptr : *it; + } + + /// An iterator class that iterates the held interface objects of the given + /// derived interface type. 
+ template + class iterator : public llvm::mapped_iterator< + InterfaceVectorT::const_iterator, + const InterfaceT &(*)(const DialectInterface *)> { + static const InterfaceT &remapIt(const DialectInterface *interface) { + return *static_cast(interface); + } + + iterator(InterfaceVectorT::const_iterator it) + : llvm::mapped_iterator< + InterfaceVectorT::const_iterator, + const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {} + + /// Allow access to the constructor. + friend DialectInterfaceCollectionBase; + }; + + /// Iterator access to the held interfaces. + template iterator interface_begin() const { + return iterator(orderedInterfaces.begin()); + } + template iterator interface_end() const { + return iterator(orderedInterfaces.end()); + } + +private: + /// A set of registered dialect interface instances. + InterfaceSetT interfaces; + /// An ordered list of the registered interface instances, necessary for + /// deterministic iteration. + // NOTE: SetVector does not provide find access, so it can't be used here. + InterfaceVectorT orderedInterfaces; +}; +} // namespace detail + +/// A collection of dialect interfaces within a context, for a given concrete +/// interface type. +template +class DialectInterfaceCollection + : public detail::DialectInterfaceCollectionBase { +public: + using Base = DialectInterfaceCollection; + + /// Collect the registered dialect interfaces within the provided context. + DialectInterfaceCollection(MLIRContext *ctx) + : detail::DialectInterfaceCollectionBase( + ctx, InterfaceType::getInterfaceID()) {} + + /// Get the interface for a given object, or null if one is not registered. + /// The object may be a dialect or an operation instance. + template + const InterfaceType *getInterfaceFor(Object *obj) const { + return static_cast( + detail::DialectInterfaceCollectionBase::getInterfaceFor(obj)); + } + + /// Iterator access to the held interfaces. 
+ using iterator = + detail::DialectInterfaceCollectionBase::iterator; + iterator begin() const { return interface_begin(); } + iterator end() const { return interface_end(); } + +private: + using detail::DialectInterfaceCollectionBase::interface_begin; + using detail::DialectInterfaceCollectionBase::interface_end; +}; + +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def new file mode 100644 index 0000000000000000000000000000000000000000..14b876a2ce91ed07dc24c4e071b1242e1d011945 --- /dev/null +++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def @@ -0,0 +1,41 @@ +//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file enumerates the different dialects that define custom classes +// within the attribute or type system. +// +//===----------------------------------------------------------------------===// + +DEFINE_SYM_KIND_RANGE(STANDARD) +DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL) +DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR) +DEFINE_SYM_KIND_RANGE(TENSORFLOW) +DEFINE_SYM_KIND_RANGE(LLVM) +DEFINE_SYM_KIND_RANGE(QUANTIZATION) +DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine +DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect +DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect +DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect +DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect +DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect +DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect + +// The following ranges are reserved for experimenting with MLIR dialects in a +// private context without having to register them here. 
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8) +DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9) + +#undef DEFINE_SYM_KIND_RANGE diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h new file mode 100644 index 0000000000000000000000000000000000000000..3f788bbeeba4ebe2e9c8da50cc7a08d957e4a2a7 --- /dev/null +++ b/mlir/include/mlir/IR/Function.h @@ -0,0 +1,201 @@ +//===- Function.h - MLIR Function Class -------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Functions are the basic unit of composition in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_FUNCTION_H +#define MLIR_IR_FUNCTION_H + +#include "mlir/Analysis/CallInterfaces.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +//===--------------------------------------------------------------------===// +// Function Operation. +//===--------------------------------------------------------------------===// + +/// FuncOp represents a function, or an operation containing one region that +/// forms a CFG(Control Flow Graph). 
The region of a function is not allowed to +/// implicitly capture global values, and all external references must use +/// Function arguments or attributes that establish a symbolic connection(e.g. +/// symbols referenced by name via a string attribute). +class FuncOp : public Op { +public: + using Op::Op; + using Op::print; + + static StringRef getOperationName() { return "func"; } + + static FuncOp create(Location location, StringRef name, FunctionType type, + ArrayRef attrs = {}); + static FuncOp create(Location location, StringRef name, FunctionType type, + iterator_range attrs); + static FuncOp create(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs); + + static void build(Builder *builder, OperationState &result, StringRef name, + FunctionType type, ArrayRef attrs); + static void build(Builder *builder, OperationState &result, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs); + + /// Operation hooks. + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + + /// Erase a single argument at `argIndex`. + void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); } + /// Erases the arguments listed in `argIndices`. + /// `argIndices` is allowed to have duplicates and can be in any order. + void eraseArguments(ArrayRef argIndices); + + /// Returns the type of this function. + FunctionType getType() { + return getAttrOfType(getTypeAttrName()) + .getValue() + .cast(); + } + + /// Change the type of this function in place. This is an extremely dangerous + /// operation and it is up to the caller to ensure that this is legal for this + /// function, and to restore invariants: + /// - the entry block args must be updated to match the function params. 
+ /// - the argument/result attributes may need an update: if the new type has + /// less parameters we drop the extra attributes, if there are more + /// parameters they won't have any attributes. + void setType(FunctionType newType) { + SmallVector nameBuf; + auto oldType = getType(); + for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; + i++) { + removeAttr(getArgAttrName(i, nameBuf)); + } + for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; + i++) { + removeAttr(getResultAttrName(i, nameBuf)); + } + setAttr(getTypeAttrName(), TypeAttr::get(newType)); + } + + /// Create a deep copy of this function and all of its blocks, remapping + /// any operands that use values outside of the function using the map that is + /// provided (leaving them alone if no entry is present). If the mapper + /// contains entries for function arguments, these arguments are not included + /// in the new function. Replaces references to cloned sub-values with the + /// corresponding value that is copied, and adds those mappings to the mapper. + FuncOp clone(BlockAndValueMapping &mapper); + FuncOp clone(); + + /// Clone the internal blocks and attributes from this function into dest. Any + /// cloned blocks are appended to the back of dest. This function asserts that + /// the attributes of the current function and dest are compatible. + void cloneInto(FuncOp dest, BlockAndValueMapping &mapper); + + //===--------------------------------------------------------------------===// + // Body Handling + //===--------------------------------------------------------------------===// + + /// Add an entry block to an empty function, and set up the block arguments + /// to match the signature of the function. The newly inserted entry block is + /// returned. + Block *addEntryBlock(); + + /// Add a normal block to the end of the function's block list. The function + /// should at least already have an entry block. 
+ Block *addBlock(); + + //===--------------------------------------------------------------------===// + // CallableOpInterface + //===--------------------------------------------------------------------===// + + /// Returns a region on the current operation that the given callable refers + /// to. This may return null in the case of an external callable object, e.g. + /// an external function. + Region *getCallableRegion(CallInterfaceCallable callable) { + assert(callable.get().getLeafReference() == getName()); + return isExternal() ? nullptr : &getBody(); + } + + /// Returns all of the callable regions of this operation. + void getCallableRegions(SmallVectorImpl &callables) { + if (!isExternal()) + callables.push_back(&getBody()); + } + + /// Returns the results types that the given callable region produces when + /// executed. + ArrayRef getCallableResults(Region *region) { + assert(!isExternal() && region == &getBody() && "invalid callable"); + return getType().getResults(); + } + +private: + // This trait needs access to the hooks defined below. + friend class OpTrait::FunctionLike; + + /// Returns the number of arguments. This is a hook for OpTrait::FunctionLike. + unsigned getNumFuncArguments() { return getType().getInputs().size(); } + + /// Returns the number of results. This is a hook for OpTrait::FunctionLike. + unsigned getNumFuncResults() { return getType().getResults().size(); } + + /// Hook for OpTrait::FunctionLike, called after verifying that the 'type' + /// attribute is present and checks if it holds a function type. Ensures + /// getType, getNumFuncArguments, and getNumFuncResults can be called safely. + LogicalResult verifyType() { + auto type = getTypeAttr().getValue(); + if (!type.isa()) + return emitOpError("requires '" + getTypeAttrName() + + "' attribute of function type"); + return success(); + } +}; +} // end namespace mlir + +namespace llvm { + +// Functions hash just like pointers. 
+template <> struct DenseMapInfo { + static mlir::FuncOp getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::FuncOp::getFromOpaquePointer(pointer); + } + static mlir::FuncOp getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::FuncOp::getFromOpaquePointer(pointer); + } + static unsigned getHashValue(mlir::FuncOp val) { + return hash_value(val.getAsOpaquePointer()); + } + static bool isEqual(mlir::FuncOp LHS, mlir::FuncOp RHS) { return LHS == RHS; } +}; + +/// Allow stealing the low bits of FuncOp. +template <> struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::FuncOp I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::FuncOp getFromVoidPointer(void *P) { + return mlir::FuncOp::getFromOpaquePointer(P); + } + enum { NumLowBitsAvailable = 3 }; +}; + +} // namespace llvm + +#endif // MLIR_IR_FUNCTION_H diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h new file mode 100644 index 0000000000000000000000000000000000000000..9d3e438f67e95be77a399604706a53007304c6cb --- /dev/null +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -0,0 +1,100 @@ +//===- FunctionImplementation.h - Function-like Op utilities ----*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utility functions for implementing function-like +// operations, in particular, parsing, printing and verification components +// common to function-like operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_FUNCTIONIMPLEMENTATION_H_ +#define MLIR_IR_FUNCTIONIMPLEMENTATION_H_ + +#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { + +namespace impl { + +/// A named class for passing around the variadic flag. +class VariadicFlag { +public: + explicit VariadicFlag(bool variadic) : variadic(variadic) {} + bool isVariadic() const { return variadic; } + +private: + /// Underlying storage. + bool variadic; +}; + +/// Adds argument and result attributes, provided as `argAttrs` and +/// `resultAttrs` arguments, to the list of operation attributes in `result`. +/// Internally, argument and result attributes are stored as dict attributes +/// with special names given by getResultAttrName, getArgumentAttrName. +void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef> argAttrs, + ArrayRef> resultAttrs); + +/// Callback type for `parseFunctionLikeOp`, the callback should produce the +/// type that will be associated with a function-like operation from lists of +/// function arguments and results, VariadicFlag indicates whether the function +/// should have variadic arguments; in case of error, it may populate the last +/// argument with a message. +using FuncTypeBuilder = function_ref, ArrayRef, VariadicFlag, std::string &)>; + +/// Parses a function signature using `parser`. The `allowVariadic` argument +/// indicates whether functions with variadic arguments are supported. The +/// trailing arguments are populated by this function with names, types and +/// attributes of the arguments and those of the results. 
+ParseResult parseFunctionSignature( + OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &argNames, + SmallVectorImpl &argTypes, + SmallVectorImpl> &argAttrs, bool &isVariadic, + SmallVectorImpl &resultTypes, + SmallVectorImpl> &resultAttrs); + +/// Parser implementation for function-like operations. Uses +/// `funcTypeBuilder` to construct the custom function type given lists of +/// input and output types. If `allowVariadic` is set, the parser will accept +/// trailing ellipsis in the function signature and indicate to the builder +/// whether the function is variadic. If the builder returns a null type, +/// `result` will not contain the `type` attribute. The caller can then add a +/// type, report the error or delegate the reporting to the op's verifier. +ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, + bool allowVariadic, + FuncTypeBuilder funcTypeBuilder); + +/// Printer implementation for function-like operations. Accepts lists of +/// argument and result types to use while printing. +void printFunctionLikeOp(OpAsmPrinter &p, Operation *op, + ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes); + +/// Prints the signature of the function-like operation `op`. Assumes `op` has +/// the FunctionLike trait and passed the verification. +void printFunctionSignature(OpAsmPrinter &p, Operation *op, + ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes); + +/// Prints the list of function prefixed with the "attributes" keyword. The +/// attributes with names listed in "elided" as well as those used by the +/// function-like operation internally are not printed. Nothing is printed +/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and +/// passed the verification. 
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, + unsigned numResults, + ArrayRef elided = {}); + +} // namespace impl + +} // namespace mlir + +#endif // MLIR_IR_FUNCTIONIMPLEMENTATION_H_ diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h new file mode 100644 index 0000000000000000000000000000000000000000..e6cba2c7404dac940e6e2f6a6b259e254e0aae56 --- /dev/null +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -0,0 +1,539 @@ +//===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines support types for Operations that represent function-like +// constructs to use. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_FUNCTIONSUPPORT_H +#define MLIR_IR_FUNCTIONSUPPORT_H + +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/SmallString.h" + +namespace mlir { + +namespace impl { + +/// Return the name of the attribute used for function types. +inline StringRef getTypeAttrName() { return "type"; } + +/// Return the name of the attribute used for function arguments. +inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl &out) { + out.clear(); + return ("arg" + Twine(arg)).toStringRef(out); +} + +/// Return the name of the attribute used for function results. +inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl &out) { + out.clear(); + return ("result" + Twine(arg)).toStringRef(out); +} + +/// Returns the dictionary attribute corresponding to the argument at 'index'. +/// If there are no argument attributes at 'index', a null attribute is +/// returned. 
+inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) { + SmallString<8> nameOut; + return op->getAttrOfType(getArgAttrName(index, nameOut)); +} + +/// Returns the dictionary attribute corresponding to the result at 'index'. +/// If there are no result attributes at 'index', a null attribute is +/// returned. +inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) { + SmallString<8> nameOut; + return op->getAttrOfType(getResultAttrName(index, nameOut)); +} + +/// Return all of the attributes for the argument at 'index'. +inline ArrayRef getArgAttrs(Operation *op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? argDict.getValue() : llvm::None; +} + +/// Return all of the attributes for the result at 'index'. +inline ArrayRef getResultAttrs(Operation *op, unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : llvm::None; +} + +} // namespace impl + +namespace OpTrait { + +/// This trait provides APIs for Ops that behave like functions. In particular: +/// - Ops must be symbols, i.e. also have the `Symbol` trait; +/// - Ops have a single region with multiple blocks that corresponds to the body +/// of the function; +/// - the absence of a region corresponds to an external function; +/// - leading arguments of the first block of the region are treated as function +/// arguments; +/// - they can have argument attributes that are stored in a dictionary +/// attribute on the Op itself. +/// This trait does *NOT* provide type support for the functions, meaning that +/// concrete Ops must handle the type of the declared or defined function. +/// `getTypeAttrName()` is a convenience function that returns the name of the +/// attribute that can be used to store the function type, but the trait makes +/// no assumption based on it. 
+/// +/// - Concrete ops *must* define a member function `getNumFuncArguments()` that +/// returns the number of function arguments based exclusively on type (so +/// that it can be called on function declarations). +/// - Concrete ops *must* define a member function `getNumFuncResults()` that +/// returns the number of function results based exclusively on type (so that +/// it can be called on function declarations). +/// - To verify that the type respects op-specific invariants, concrete ops may +/// redefine the `verifyType()` hook that will be called after verifying the +/// presence of the `type` attribute and before any call to +/// `getNumFuncArguments`/`getNumFuncResults` from the verifier. +/// - To verify that the body respects op-specific invariants, concrete ops may +/// redefine the `verifyBody()` hook that will be called after verifying the +/// function type and the presence of the (potentially empty) body region. +template +class FunctionLike : public OpTrait::TraitBase { +public: + /// Verify that all of the argument attributes are dialect attributes. + static LogicalResult verifyTrait(Operation *op); + + //===--------------------------------------------------------------------===// + // Body Handling + //===--------------------------------------------------------------------===// + + /// Returns true if this function is external, i.e. it has no body. + bool isExternal() { return empty(); } + + Region &getBody() { return this->getOperation()->getRegion(0); } + + /// Delete all blocks from this function. + void eraseBody() { + getBody().dropAllReferences(); + getBody().getBlocks().clear(); + } + + /// This is the list of blocks in the function. + using BlockListType = Region::BlockListType; + BlockListType &getBlocks() { return getBody().getBlocks(); } + + // Iteration over the block in the function. 
+ using iterator = BlockListType::iterator; + using reverse_iterator = BlockListType::reverse_iterator; + + iterator begin() { return getBody().begin(); } + iterator end() { return getBody().end(); } + reverse_iterator rbegin() { return getBody().rbegin(); } + reverse_iterator rend() { return getBody().rend(); } + + bool empty() { return getBody().empty(); } + void push_back(Block *block) { getBody().push_back(block); } + void push_front(Block *block) { getBody().push_front(block); } + + Block &back() { return getBody().back(); } + Block &front() { return getBody().front(); } + + /// Hook for concrete ops to verify the contents of the body. Called as a + /// part of trait verification, after type verification and ensuring that a + /// region exists. + LogicalResult verifyBody(); + + //===--------------------------------------------------------------------===// + // Type Attribute Handling + //===--------------------------------------------------------------------===// + + /// Return the name of the attribute used for function types. + static StringRef getTypeAttrName() { return ::mlir::impl::getTypeAttrName(); } + + TypeAttr getTypeAttr() { + return this->getOperation()->template getAttrOfType( + getTypeAttrName()); + } + + bool isTypeAttrValid() { + auto typeAttr = getTypeAttr(); + if (!typeAttr) + return false; + return typeAttr.getValue() != Type{}; + } + + //===--------------------------------------------------------------------===// + // Argument Handling + //===--------------------------------------------------------------------===// + + unsigned getNumArguments() { + return static_cast(this)->getNumFuncArguments(); + } + + unsigned getNumResults() { + return static_cast(this)->getNumFuncResults(); + } + + /// Gets argument. + BlockArgument getArgument(unsigned idx) { + return getBlocks().front().getArgument(idx); + } + + // Supports non-const operand iteration. 
+ using args_iterator = Block::args_iterator; + args_iterator args_begin() { return front().args_begin(); } + args_iterator args_end() { return front().args_end(); } + iterator_range getArguments() { + return {args_begin(), args_end()}; + } + + //===--------------------------------------------------------------------===// + // Argument Attributes + //===--------------------------------------------------------------------===// + + /// FunctionLike operations allow for attaching attributes to each of the + /// respective function arguments. These argument attributes are stored as + /// DictionaryAttrs in the main operation attribute dictionary. The name of + /// these entries is `arg` followed by the index of the argument. These + /// argument attribute dictionaries are optional, and will generally only + /// exist if they are non-empty. + + /// Return all of the attributes for the argument at 'index'. + ArrayRef getArgAttrs(unsigned index) { + return ::mlir::impl::getArgAttrs(this->getOperation(), index); + } + + /// Return all argument attributes of this function. + void getAllArgAttrs(SmallVectorImpl &result) { + for (unsigned i = 0, e = getNumArguments(); i != e; ++i) + result.emplace_back(getArgAttrDict(i)); + } + + /// Return the specified attribute, if present, for the argument at 'index', + /// null otherwise. + Attribute getArgAttr(unsigned index, Identifier name) { + auto argDict = getArgAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + Attribute getArgAttr(unsigned index, StringRef name) { + auto argDict = getArgAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + + template + AttrClass getArgAttrOfType(unsigned index, Identifier name) { + return getArgAttr(index, name).template dyn_cast_or_null(); + } + template + AttrClass getArgAttrOfType(unsigned index, StringRef name) { + return getArgAttr(index, name).template dyn_cast_or_null(); + } + + /// Set the attributes held by the argument at 'index'. 
+ void setArgAttrs(unsigned index, ArrayRef attributes); + void setArgAttrs(unsigned index, NamedAttributeList attributes); + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumArguments()); + for (unsigned i = 0, e = attributes.size(); i != e; ++i) + setArgAttrs(i, attributes[i]); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setArgAttr(unsigned index, Identifier name, Attribute value); + void setArgAttr(unsigned index, StringRef name, Attribute value) { + setArgAttr(index, Identifier::get(name, this->getOperation()->getContext()), + value); + } + + /// Remove the attribute 'name' from the argument at 'index'. + NamedAttributeList::RemoveResult removeArgAttr(unsigned index, + Identifier name); + + //===--------------------------------------------------------------------===// + // Result Attributes + //===--------------------------------------------------------------------===// + + /// FunctionLike operations allow for attaching attributes to each of the + /// respective function results. These result attributes are stored as + /// DictionaryAttrs in the main operation attribute dictionary. The name of + /// these entries is `result` followed by the index of the result. These + /// result attribute dictionaries are optional, and will generally only + /// exist if they are non-empty. + + /// Return all of the attributes for the result at 'index'. + ArrayRef getResultAttrs(unsigned index) { + return ::mlir::impl::getResultAttrs(this->getOperation(), index); + } + + /// Return all result attributes of this function. + void getAllResultAttrs(SmallVectorImpl &result) { + for (unsigned i = 0, e = getNumResults(); i != e; ++i) + result.emplace_back(getResultAttrDict(i)); + } + + /// Return the specified attribute, if present, for the result at 'index', + /// null otherwise. 
+ Attribute getResultAttr(unsigned index, Identifier name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + Attribute getResultAttr(unsigned index, StringRef name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + + template + AttrClass getResultAttrOfType(unsigned index, Identifier name) { + return getResultAttr(index, name).template dyn_cast_or_null(); + } + template + AttrClass getResultAttrOfType(unsigned index, StringRef name) { + return getResultAttr(index, name).template dyn_cast_or_null(); + } + + /// Set the attributes held by the result at 'index'. + void setResultAttrs(unsigned index, ArrayRef attributes); + void setResultAttrs(unsigned index, NamedAttributeList attributes); + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumResults()); + for (unsigned i = 0, e = attributes.size(); i != e; ++i) + setResultAttrs(i, attributes[i]); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setResultAttr(unsigned index, Identifier name, Attribute value); + void setResultAttr(unsigned index, StringRef name, Attribute value) { + setResultAttr(index, + Identifier::get(name, this->getOperation()->getContext()), + value); + } + + /// Remove the attribute 'name' from the result at 'index'. + NamedAttributeList::RemoveResult removeResultAttr(unsigned index, + Identifier name); + +protected: + /// Returns the attribute entry name for the set of argument attributes at + /// 'index'. + static StringRef getArgAttrName(unsigned index, SmallVectorImpl &out) { + return ::mlir::impl::getArgAttrName(index, out); + } + + /// Returns the dictionary attribute corresponding to the argument at 'index'. + /// If there are no argument attributes at 'index', a null attribute is + /// returned. 
+ DictionaryAttr getArgAttrDict(unsigned index) { + assert(index < getNumArguments() && "invalid argument number"); + return ::mlir::impl::getArgAttrDict(this->getOperation(), index); + } + + /// Returns the attribute entry name for the set of result attributes at + /// 'index'. + static StringRef getResultAttrName(unsigned index, + SmallVectorImpl &out) { + return ::mlir::impl::getResultAttrName(index, out); + } + + /// Returns the dictionary attribute corresponding to the result at 'index'. + /// If there are no result attributes at 'index', a null attribute is + /// returned. + DictionaryAttr getResultAttrDict(unsigned index) { + assert(index < getNumResults() && "invalid result number"); + return ::mlir::impl::getResultAttrDict(this->getOperation(), index); + } + + /// Hook for concrete classes to verify that the type attribute respects + /// op-specific invariants. Default implementation always succeeds. + LogicalResult verifyType() { return success(); } +}; + +/// Default verifier checks that if the entry block exists, it has the same +/// number of arguments as the function-like operation. +template +LogicalResult FunctionLike::verifyBody() { + auto funcOp = cast(this->getOperation()); + + if (funcOp.isExternal()) + return success(); + + unsigned numArguments = funcOp.getNumArguments(); + if (funcOp.front().getNumArguments() != numArguments) + return funcOp.emitOpError("entry block must have ") + << numArguments << " arguments to match function signature"; + + return success(); +} + +template +LogicalResult FunctionLike::verifyTrait(Operation *op) { + MLIRContext *ctx = op->getContext(); + auto funcOp = cast(op); + + if (!funcOp.isTypeAttrValid()) + return funcOp.emitOpError("requires a type attribute '") + << getTypeAttrName() << '\''; + + if (failed(funcOp.verifyType())) + return failure(); + + for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) { + // Verify that all of the argument attributes are dialect attributes, i.e. 
+ // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : funcOp.getArgAttrs(i)) { + if (!attr.first.strref().contains('.')) + return funcOp.emitOpError("arguments may only have dialect attributes"); + auto dialectNamePair = attr.first.strref().split('.'); + if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) { + if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, + /*argIndex=*/i, attr))) + return failure(); + } + } + } + + for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) { + // Verify that all of the result attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : funcOp.getResultAttrs(i)) { + if (!attr.first.strref().contains('.')) + return funcOp.emitOpError("results may only have dialect attributes"); + auto dialectNamePair = attr.first.strref().split('.'); + if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) { + if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, + /*resultIndex=*/i, + attr))) + return failure(); + } + } + } + + // Check that the op has exactly one region for the body. + if (op->getNumRegions() != 1) + return funcOp.emitOpError("expects one region"); + + return funcOp.verifyBody(); +} + +//===----------------------------------------------------------------------===// +// Function Argument Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the argument at 'index'. 
+template +void FunctionLike::setArgAttrs( + unsigned index, ArrayRef attributes) { + assert(index < getNumArguments() && "invalid argument number"); + SmallString<8> nameOut; + getArgAttrName(index, nameOut); + + if (attributes.empty()) + return (void)static_cast(this)->removeAttr(nameOut); + Operation *op = this->getOperation(); + op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); +} + +template +void FunctionLike::setArgAttrs(unsigned index, + NamedAttributeList attributes) { + assert(index < getNumArguments() && "invalid argument number"); + SmallString<8> nameOut; + if (auto newAttr = attributes.getDictionary()) + return this->getOperation()->setAttr(getArgAttrName(index, nameOut), + newAttr); + static_cast(this)->removeAttr(getArgAttrName(index, nameOut)); +} + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +template +void FunctionLike::setArgAttr(unsigned index, Identifier name, + Attribute value) { + auto curAttr = getArgAttrDict(index); + NamedAttributeList attrList(curAttr); + attrList.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (curAttr != attrList.getDictionary()) + setArgAttrs(index, attrList); +} + +/// Remove the attribute 'name' from the argument at 'index'. +template +NamedAttributeList::RemoveResult +FunctionLike::removeArgAttr(unsigned index, Identifier name) { + // Build an attribute list and remove the attribute at 'name'. + NamedAttributeList attrList(getArgAttrDict(index)); + auto result = attrList.remove(name); + + // If the attribute was removed, then update the argument dictionary. + if (result == NamedAttributeList::RemoveResult::Removed) + setArgAttrs(index, attrList); + return result; +} + +//===----------------------------------------------------------------------===// +// Function Result Attribute. 
+//===----------------------------------------------------------------------===// + +/// Set the attributes held by the result at 'index'. +template +void FunctionLike::setResultAttrs( + unsigned index, ArrayRef attributes) { + assert(index < getNumResults() && "invalid result number"); + SmallString<8> nameOut; + getResultAttrName(index, nameOut); + + if (attributes.empty()) + return (void)static_cast(this)->removeAttr(nameOut); + Operation *op = this->getOperation(); + op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); +} + +template +void FunctionLike::setResultAttrs(unsigned index, + NamedAttributeList attributes) { + assert(index < getNumResults() && "invalid result number"); + SmallString<8> nameOut; + if (auto newAttr = attributes.getDictionary()) + return this->getOperation()->setAttr(getResultAttrName(index, nameOut), + newAttr); + static_cast(this)->removeAttr( + getResultAttrName(index, nameOut)); +} + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +template +void FunctionLike::setResultAttr(unsigned index, Identifier name, + Attribute value) { + auto curAttr = getResultAttrDict(index); + NamedAttributeList attrList(curAttr); + attrList.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (curAttr != attrList.getDictionary()) + setResultAttrs(index, attrList); +} + +/// Remove the attribute 'name' from the result at 'index'. +template +NamedAttributeList::RemoveResult +FunctionLike::removeResultAttr(unsigned index, Identifier name) { + // Build an attribute list and remove the attribute at 'name'. + NamedAttributeList attrList(getResultAttrDict(index)); + auto result = attrList.remove(name); + + // If the attribute was removed, then update the result dictionary. 
+ if (result == NamedAttributeList::RemoveResult::Removed) + setResultAttrs(index, attrList); + return result; +} + +} // end namespace OpTrait + +} // end namespace mlir + +#endif // MLIR_IR_FUNCTIONSUPPORT_H diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h new file mode 100644 index 0000000000000000000000000000000000000000..604eebf341e4f1de8ca2c50fcb8346b285756fec --- /dev/null +++ b/mlir/include/mlir/IR/Identifier.h @@ -0,0 +1,134 @@ +//===- Identifier.h - MLIR Identifier Class ---------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_IDENTIFIER_H +#define MLIR_IR_IDENTIFIER_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +class MLIRContext; + +/// This class represents a uniqued string owned by an MLIRContext. Strings +/// represented by this type cannot contain nul characters, and may not have a +/// zero length. +/// +/// This is a POD type with pointer size, so it should be passed around by +/// value. The underlying data is owned by MLIRContext and is thus immortal for +/// almost all clients. +class Identifier { +public: + /// Return an identifier for the specified string. + static Identifier get(StringRef str, MLIRContext *context); + Identifier(const Identifier &) = default; + Identifier &operator=(const Identifier &other) = default; + + /// Return a StringRef for the string. + StringRef strref() const { return StringRef(pointer, size()); } + + /// Identifiers implicitly convert to StringRefs. + operator StringRef() const { return strref(); } + + /// Return an std::string. 
+ std::string str() const { return strref().str(); } + + /// Return a null terminated C string. + const char *c_str() const { return pointer; } + + /// Return a pointer to the start of the string data. + const char *data() const { return pointer; } + + /// Return the number of bytes in this string. + unsigned size() const { return ::strlen(pointer); } + + /// Return true if this identifier is the specified string. + bool is(StringRef string) const { return strref().equals(string); } + + const char *begin() const { return pointer; } + const char *end() const { return pointer + size(); } + + void print(raw_ostream &os) const; + void dump() const; + + const void *getAsOpaquePointer() const { + return static_cast(pointer); + } + static Identifier getFromOpaquePointer(const void *pointer) { + return Identifier((const char *)pointer); + } + +private: + /// These are the bytes of the string, which is a nul terminated string. + const char *pointer; + explicit Identifier(const char *pointer) : pointer(pointer) {} +}; + +inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) { + identifier.print(os); + return os; +} + +inline bool operator==(Identifier lhs, Identifier rhs) { + return lhs.data() == rhs.data(); +} + +inline bool operator!=(Identifier lhs, Identifier rhs) { + return lhs.data() != rhs.data(); +} + +inline bool operator==(Identifier lhs, StringRef rhs) { return lhs.is(rhs); } +inline bool operator!=(Identifier lhs, StringRef rhs) { return !lhs.is(rhs); } +inline bool operator==(StringRef lhs, Identifier rhs) { return rhs.is(lhs); } +inline bool operator!=(StringRef lhs, Identifier rhs) { return !rhs.is(lhs); } + +// Make identifiers hashable. +inline llvm::hash_code hash_value(Identifier arg) { + return llvm::hash_value(arg.strref()); +} + +} // end namespace mlir + +namespace llvm { +// Identifiers hash just like pointers, there is no need to hash the bytes. 
+template <> +struct DenseMapInfo { + static mlir::Identifier getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::Identifier::getFromOpaquePointer(pointer); + } + static mlir::Identifier getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::Identifier::getFromOpaquePointer(pointer); + } + static unsigned getHashValue(mlir::Identifier Val) { + return DenseMapInfo::getHashValue(Val.data()); + } + static bool isEqual(mlir::Identifier LHS, mlir::Identifier RHS) { + return LHS == RHS; + } +}; + +/// The pointer inside of an identifier comes from a StringMap, so its alignment +/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to +/// steal the low bits. +template <> +struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::Identifier I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::Identifier getFromVoidPointer(void *P) { + return mlir::Identifier::getFromOpaquePointer(P); + } + enum { NumLowBitsAvailable = 2 }; +}; + +} // end namespace llvm +#endif diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h new file mode 100644 index 0000000000000000000000000000000000000000..1238511df34cff9f177f690c78c0fad62b9ff5a2 --- /dev/null +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -0,0 +1,142 @@ +//===- IntegerSet.h - MLIR Integer Set Class --------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Integer sets are sets of points from the integer lattice constrained by +// affine equality/inequality constraints. 
This class is meant to represent +// integer sets in the IR - for 'affine.if' operations and as attributes of +// other operations. It is typically expected to contain only a handful of +// affine constraints, and is immutable like an affine map. Integer sets are not +// unique'd - although affine expressions that make up its equalities and +// inequalities are themselves unique. + +// This class is not meant for affine analysis and operations like set +// operations, emptiness checks, or other math operations for analysis and +// transformation. For the latter, use FlatAffineConstraints. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_INTEGER_SET_H +#define MLIR_IR_INTEGER_SET_H + +#include "mlir/IR/AffineExpr.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +namespace detail { +struct IntegerSetStorage; +} + +class MLIRContext; + +/// An integer set representing a conjunction of one or more affine equalities +/// and inequalities. An integer set in the IR is immutable like the affine map, +/// but integer sets are not unique'd. The affine expressions that make up the +/// equalities and inequalities of an integer set are themselves unique and are +/// allocated by the bump pointer allocator. +class IntegerSet { +public: + using ImplType = detail::IntegerSetStorage; + + IntegerSet() : set(nullptr) {} + explicit IntegerSet(ImplType *set) : set(set) {} + IntegerSet(const IntegerSet &other) : set(other.set) {} + IntegerSet &operator=(const IntegerSet &other) = default; + + static IntegerSet get(unsigned dimCount, unsigned symbolCount, + ArrayRef constraints, + ArrayRef eqFlags); + + // Returns the canonical empty IntegerSet (i.e. a set with no integer points). 
+ static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols, + MLIRContext *context) { + auto one = getAffineConstantExpr(1, context); + /* 1 == 0 */ + return get(numDims, numSymbols, one, true); + } + + /// Returns true if this is the canonical integer set. + bool isEmptyIntegerSet() const; + + /// This method substitutes any uses of dimensions and symbols (e.g. + /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified + /// integer set. Because this can be used to eliminate dims and + /// symbols, the client needs to specify the number of dims and symbols in + /// the result. The returned map always has the same number of results. + IntegerSet replaceDimsAndSymbols(ArrayRef dimReplacements, + ArrayRef symReplacements, + unsigned numResultDims, + unsigned numResultSyms); + + explicit operator bool() { return set; } + bool operator==(IntegerSet other) const { return set == other.set; } + + unsigned getNumDims() const; + unsigned getNumSymbols() const; + unsigned getNumInputs() const; + unsigned getNumConstraints() const; + unsigned getNumEqualities() const; + unsigned getNumInequalities() const; + + ArrayRef getConstraints() const; + + AffineExpr getConstraint(unsigned idx) const; + + /// Returns the equality bits, which specify whether each of the constraints + /// is an equality or inequality. + ArrayRef getEqFlags() const; + + /// Returns true if the idx^th constraint is an equality, false if it is an + /// inequality. + bool isEq(unsigned idx) const; + + MLIRContext *getContext() const; + + /// Walk all of the AffineExpr's in this set's constraints. Each node in an + /// expression tree is visited in postorder. + void walkExprs(function_ref callback) const; + + void print(raw_ostream &os) const; + void dump() const; + + friend ::llvm::hash_code hash_value(IntegerSet arg); + +private: + ImplType *set; + /// Sets with constraints fewer than kUniquingThreshold are uniqued. 
+ constexpr static unsigned kUniquingThreshold = 4; +}; + +// Make AffineExpr hashable. +inline ::llvm::hash_code hash_value(IntegerSet arg) { + return ::llvm::hash_value(arg.set); +} + +} // end namespace mlir +namespace llvm { + +// IntegerSet hash just like pointers +template <> struct DenseMapInfo { + static mlir::IntegerSet getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::IntegerSet(static_cast(pointer)); + } + static mlir::IntegerSet getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::IntegerSet(static_cast(pointer)); + } + static unsigned getHashValue(mlir::IntegerSet val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::IntegerSet LHS, mlir::IntegerSet RHS) { + return LHS == RHS; + } +}; + +} // namespace llvm +#endif // MLIR_IR_INTEGER_SET_H diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h new file mode 100644 index 0000000000000000000000000000000000000000..c36bcb3073541f00786962d9ac196a1cd6a4909f --- /dev/null +++ b/mlir/include/mlir/IR/Location.h @@ -0,0 +1,332 @@ +//===- Location.h - MLIR Location Classes -----------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// These classes provide the ability to relate MLIR objects back to source +// location position information. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_LOCATION_H +#define MLIR_IR_LOCATION_H + +#include "mlir/IR/Attributes.h" + +namespace mlir { + +class Attribute; +class MLIRContext; +class Identifier; + +namespace detail { + +struct CallSiteLocationStorage; +struct FileLineColLocationStorage; +struct FusedLocationStorage; +struct LocationStorage; +struct NameLocationStorage; +struct OpaqueLocationStorage; +struct UnknownLocationStorage; + +} // namespace detail + +/// Location objects represent source locations information in MLIR. +/// LocationAttr acts as the anchor for all Location based attributes. +class LocationAttr : public Attribute { +public: + using Attribute::Attribute; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(Attribute attr) { + return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR && + attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR; + } +}; + +/// This class defines the main interface for locations in MLIR and acts as a +/// non-nullable wrapper around a LocationAttr. +class Location { +public: + Location(LocationAttr loc) : impl(loc) { + assert(loc && "location should never be null."); + } + + /// Access the impl location attribute. + operator LocationAttr() const { return impl; } + LocationAttr *operator->() const { return const_cast(&impl); } + + /// Type casting utilities on the underlying location. + template bool isa() const { return impl.isa(); } + template U dyn_cast() const { return impl.dyn_cast(); } + template U cast() const { return impl.cast(); } + + /// Comparison operators. + bool operator==(Location rhs) const { return impl == rhs.impl; } + bool operator!=(Location rhs) const { return !(*this == rhs); } + + /// Print the location. 
+ void print(raw_ostream &os) const { impl.print(os); } + void dump() const { impl.dump(); } + + friend ::llvm::hash_code hash_value(Location arg); + + /// Methods for supporting PointerLikeTypeTraits. + const void *getAsOpaquePointer() const { return impl.getAsOpaquePointer(); } + static Location getFromOpaquePointer(const void *pointer) { + return LocationAttr(reinterpret_cast(pointer)); + } + +protected: + /// The internal backing location attribute. + LocationAttr impl; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const Location &loc) { + loc.print(os); + return os; +} + +/// Represents a location as call site. "callee" is the concrete location +/// (Unknown/NameLocation/FileLineColLoc/OpaqueLoc) and "caller" points to the +/// caller's location (another CallLocation or a concrete location). Multiple +/// CallSiteLocs can be chained to form a call stack. +class CallSiteLoc + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Return a uniqued call location object. + static Location get(Location callee, Location caller); + + /// Return a call site location which represents a name reference in one line + /// or a stack of frames. The input frames are ordered from innermost to + /// outermost. + static Location get(Location name, ArrayRef frames); + + /// The concrete location information this object presents. + Location getCallee() const; + + /// The caller's location. + Location getCaller() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::CallSiteLocation; + } +}; + +/// Represents a location derived from a file/line/column location. The column +/// and line may be zero to represent unknown column and/or unknown line/column +/// information. +class FileLineColLoc + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Return a uniqued FileLineCol location object. 
+ static Location get(Identifier filename, unsigned line, unsigned column, + MLIRContext *context); + static Location get(StringRef filename, unsigned line, unsigned column, + MLIRContext *context); + + StringRef getFilename() const; + + unsigned getLine() const; + unsigned getColumn() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::FileLineColLocation; + } +}; + +/// Represents a value composed of multiple source constructs, with an optional +/// metadata attribute. +class FusedLoc : public Attribute::AttrBase { +public: + using Base::Base; + + /// Return a uniqued Fused Location object. The first location in the list + /// will get precedence during diagnostic emission, with the rest being + /// displayed as supplementary "fused from here" style notes. + static Location get(ArrayRef locs, Attribute metadata, + MLIRContext *context); + static Location get(ArrayRef locs, MLIRContext *context) { + return get(locs, Attribute(), context); + } + + ArrayRef getLocations() const; + + /// Returns the optional metadata attached to this fused location. Given that + /// it is optional, the return value may be a null node. + Attribute getMetadata() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::FusedLocation; + } +}; + +/// Represents an identity name attached to a child location. +class NameLoc : public Attribute::AttrBase { +public: + using Base::Base; + + /// Return a uniqued name location object. The child location must not be + /// another NameLoc. + static Location get(Identifier name, Location child); + + /// Return a uniqued name location object with an unknown child. + static Location get(Identifier name, MLIRContext *context); + + /// Return the name identifier. + Identifier getName() const; + + /// Return the child location. 
+ Location getChildLoc() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::NameLocation; + } +}; + +/// Represents an unknown location. This is always a singleton for a given +/// MLIRContext. +class UnknownLoc : public Attribute::AttrBase { +public: + using Base::Base; + + /// Get an instance of the UnknownLoc. + static Location get(MLIRContext *context); + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::UnknownLocation; + } +}; + +/// Represents a location that is external to MLIR. Contains a pointer to some +/// data structure and an optional location that can be used if the first one is +/// not suitable. Since it contains an external structure, only optional +/// location is used during serialization. +/// The class also provides a number of methods for making type-safe casts +/// between a pointer to an object and opaque location. +class OpaqueLoc : public Attribute::AttrBase { +public: + using Base::Base; + + /// Returns an instance of opaque location which contains a given pointer to + /// an object. The corresponding MLIR location is set to UnknownLoc. + template + static Location get(T underlyingLocation, MLIRContext *context) { + return get(reinterpret_cast(underlyingLocation), + ClassID::getID(), UnknownLoc::get(context)); + } + + /// Returns an instance of opaque location which contains a given pointer to + /// an object and an additional MLIR location. + template + static Location get(T underlyingLocation, Location fallbackLocation) { + return get(reinterpret_cast(underlyingLocation), + ClassID::getID(), fallbackLocation); + } + + /// Returns a pointer to some data structure that opaque location stores. 
+ template static T getUnderlyingLocation(Location location) { + assert(isa(location)); + return reinterpret_cast( + location.cast().getUnderlyingLocation()); + } + + /// Returns a pointer to some data structure that opaque location stores. + /// Returns nullptr if provided location is not opaque location or if it + /// contains a pointer of different type. + template + static T getUnderlyingLocationOrNull(Location location) { + return isa(location) + ? reinterpret_cast( + location.cast().getUnderlyingLocation()) + : T(nullptr); + } + + /// Checks whether provided location is opaque location and contains a pointer + /// to an object of particular type. + template static bool isa(Location location) { + auto opaque_loc = location.dyn_cast(); + return opaque_loc && opaque_loc.getClassId() == ClassID::getID(); + } + + /// Returns a pointer to the corresponding object. + uintptr_t getUnderlyingLocation() const; + + /// Returns a ClassID* that represents the underlying objects c++ type. + ClassID *getClassId() const; + + /// Returns a fallback location. + Location getFallbackLocation() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::OpaqueLocation; + } + +private: + static Location get(uintptr_t underlyingLocation, ClassID *classID, + Location fallbackLocation); +}; + +// Make Location hashable. +inline ::llvm::hash_code hash_value(Location arg) { + return hash_value(arg.impl); +} + +} // end namespace mlir + +namespace llvm { + +// Type hash just like pointers. 
+template <> struct DenseMapInfo { + static mlir::Location getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::Location::getFromOpaquePointer(pointer); + } + static mlir::Location getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::Location::getFromOpaquePointer(pointer); + } + static unsigned getHashValue(mlir::Location val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::Location LHS, mlir::Location RHS) { + return LHS == RHS; + } +}; + +/// We align LocationStorage by 8, so allow LLVM to steal the low bits. +template <> struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::Location I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::Location getFromVoidPointer(void *P) { + return mlir::Location::getFromOpaquePointer(P); + } + enum { + NumLowBitsAvailable = + PointerLikeTypeTraits::NumLowBitsAvailable + }; +}; + +} // namespace llvm + +#endif diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h new file mode 100644 index 0000000000000000000000000000000000000000..e0761bcaaf13546b0f67f6c7da6efd579c8dfbbe --- /dev/null +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -0,0 +1,83 @@ +//===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_MLIRCONTEXT_H +#define MLIR_IR_MLIRCONTEXT_H + +#include "mlir/Support/LLVM.h" +#include +#include +#include + +namespace mlir { +class AbstractOperation; +class DiagnosticEngine; +class Dialect; +class InFlightDiagnostic; +class Location; +class MLIRContextImpl; +class StorageUniquer; + +/// MLIRContext is the top-level object for a collection of MLIR modules. It +/// holds immortal uniqued objects like types, and the tables used to unique +/// them. +/// +/// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with +/// a very generic name ("Context") and because it is uncommon for clients to +/// interact with it. +/// +class MLIRContext { +public: + explicit MLIRContext(); + ~MLIRContext(); + + /// Return information about all registered IR dialects. + std::vector getRegisteredDialects(); + + /// Get a registered IR dialect with the given namespace. If an exact match is + /// not found, then return nullptr. + Dialect *getRegisteredDialect(StringRef name); + + /// Get a registered IR dialect for the given derived dialect type. The + /// derived type must provide a static 'getDialectNamespace' method. + template T *getRegisteredDialect() { + return static_cast(getRegisteredDialect(T::getDialectNamespace())); + } + + /// Return information about all registered operations. This isn't very + /// efficient: typically you should ask the operations about their properties + /// directly. + std::vector getRegisteredOperations(); + + // This is effectively private given that only MLIRContext.cpp can see the + // MLIRContextImpl type. + MLIRContextImpl &getImpl() { return *impl; } + + /// Returns the diagnostic engine for this context. + DiagnosticEngine &getDiagEngine(); + + /// Returns the storage uniquer used for creating affine constructs. 
+ StorageUniquer &getAffineUniquer(); + + /// Returns the storage uniquer used for constructing type storage instances. + /// This should not be used directly. + StorageUniquer &getTypeUniquer(); + + /// Returns the storage uniquer used for constructing attribute storage + /// instances. This should not be used directly. + StorageUniquer &getAttributeUniquer(); + +private: + const std::unique_ptr impl; + + MLIRContext(const MLIRContext &) = delete; + void operator=(const MLIRContext &) = delete; +}; +} // end namespace mlir + +#endif // MLIR_IR_MLIRCONTEXT_H diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h new file mode 100644 index 0000000000000000000000000000000000000000..2cfa2428bd590e4e519e99f7f9e53003cbf24953 --- /dev/null +++ b/mlir/include/mlir/IR/Matchers.h @@ -0,0 +1,261 @@ +//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides a simple and efficient mechanism for performing general +// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's +// include/llvm/IR/PatternMatch.h. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_MATCHERS_H +#define MLIR_MATCHERS_H + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +namespace detail { + +/// The matcher that matches a certain kind of Attribute and binds the value +/// inside the Attribute. 
+template < + typename AttrClass, + // Require AttrClass to be a derived class from Attribute and get its + // value type + typename ValueType = + typename std::enable_if::value, + AttrClass>::type::ValueType, + // Require the ValueType is not void + typename = typename std::enable_if::value>::type> +struct attr_value_binder { + ValueType *bind_value; + + /// Creates a matcher instance that binds the value to bv if match succeeds. + attr_value_binder(ValueType *bv) : bind_value(bv) {} + + bool match(const Attribute &attr) { + if (auto intAttr = attr.dyn_cast()) { + *bind_value = intAttr.getValue(); + return true; + } + return false; + } +}; + +/// The matcher that matches a constant foldable operation that has no side +/// effect, no operands and produces a single result. +template struct constant_op_binder { + AttrT *bind_value; + + /// Creates a matcher instance that binds the constant attribute value to + /// bind_value if match succeeds. + constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {} + + bool match(Operation *op) { + if (op->getNumOperands() > 0 || op->getNumResults() != 1) + return false; + if (!op->hasNoSideEffect()) + return false; + + SmallVector foldedOp; + if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) { + if (auto attr = foldedOp.front().dyn_cast()) { + if ((*bind_value = attr.dyn_cast())) + return true; + } + } + return false; + } +}; + +/// The matcher that matches a constant scalar / vector splat / tensor splat +/// integer operation and binds the constant integer value. +struct constant_int_op_binder { + IntegerAttr::ValueType *bind_value; + + /// Creates a matcher instance that binds the value to bv if match succeeds. 
+ constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} + + bool match(Operation *op) { + Attribute attr; + if (!constant_op_binder(&attr).match(op)) + return false; + auto type = op->getResult(0)->getType(); + + if (type.isIntOrIndex()) { + return attr_value_binder(bind_value).match(attr); + } + if (type.isa() || type.isa()) { + if (auto splatAttr = attr.dyn_cast()) { + return attr_value_binder(bind_value) + .match(splatAttr.getSplatValue()); + } + } + return false; + } +}; + +/// The matcher that matches a given target constant scalar / vector splat / +/// tensor splat integer value. +template struct constant_int_value_matcher { + bool match(Operation *op) { + APInt value; + return constant_int_op_binder(&value).match(op) && TargetValue == value; + } +}; + +/// The matcher that matches anything except the given target constant scalar / +/// vector splat / tensor splat integer value. +template struct constant_int_not_value_matcher { + bool match(Operation *op) { + APInt value; + return constant_int_op_binder(&value).match(op) && TargetNotValue != value; + } +}; + +/// The matcher that matches a certain kind of op. +template struct op_matcher { + bool match(Operation *op) { return isa(op); } +}; + +/// Trait to check whether T provides a 'match' method with type +/// `OperationOrValue`. +template +using has_operation_or_value_matcher_t = + decltype(std::declval().match(std::declval())); + +/// Statically switch to a Value matcher. +template +typename std::enable_if_t::value, + bool> +matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { + return matcher.match(op->getOperand(idx)); +} + +/// Statically switch to an Operation matcher. +template +typename std::enable_if_t::value, + bool> +matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { + if (auto defOp = op->getOperand(idx)->getDefiningOp()) + return matcher.match(defOp); + return false; +} + +/// Terminal matcher, always returns true. 
+struct AnyValueMatcher { + bool match(Value op) const { return true; } +}; + +/// Binds to a specific value and matches it. +struct PatternMatcherValue { + PatternMatcherValue(Value val) : value(val) {} + bool match(Value val) const { return val == value; } + Value value; +}; + +template +constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, + std::index_sequence) { + (void)std::initializer_list{ + 0, + (callback(std::integral_constant{}, std::get(tuple)), + 0)...}; +} + +template +constexpr void enumerate(std::tuple &tuple, CallbackT &&callback) { + detail::enumerateImpl(tuple, std::forward(callback), + std::make_index_sequence{}); +} + +/// RecursivePatternMatcher that composes. +template +struct RecursivePatternMatcher { + RecursivePatternMatcher(OperandMatchers... matchers) + : operandMatchers(matchers...) {} + bool match(Operation *op) { + if (!isa(op) || op->getNumOperands() != sizeof...(OperandMatchers)) + return false; + bool res = true; + enumerate(operandMatchers, [&](size_t index, auto &matcher) { + res &= matchOperandOrValueAtIndex(op, index, matcher); + }); + return res; + } + std::tuple operandMatchers; +}; + +} // end namespace detail + +/// Matches a value from a constant foldable operation and writes the value to +/// bind_value. +template +inline detail::constant_op_binder m_Constant(AttrT *bind_value) { + return detail::constant_op_binder(bind_value); +} + +/// Matches a constant scalar / vector splat / tensor splat integer one. +inline detail::constant_int_value_matcher<1> m_One() { + return detail::constant_int_value_matcher<1>(); +} + +/// Matches the given OpClass. +template inline detail::op_matcher m_Op() { + return detail::op_matcher(); +} + +/// Matches a constant scalar / vector splat / tensor splat integer zero. +inline detail::constant_int_value_matcher<0> m_Zero() { + return detail::constant_int_value_matcher<0>(); +} + +/// Matches a constant scalar / vector splat / tensor splat integer that is any +/// non-zero value. 
+inline detail::constant_int_not_value_matcher<0> m_NonZero() { + return detail::constant_int_not_value_matcher<0>(); +} + +/// Entry point for matching a pattern over a Value. +template +inline bool matchPattern(Value value, const Pattern &pattern) { + // TODO: handle other cases + if (auto *op = value->getDefiningOp()) + return const_cast(pattern).match(op); + return false; +} + +/// Entry point for matching a pattern over an Operation. +template +inline bool matchPattern(Operation *op, const Pattern &pattern) { + return const_cast(pattern).match(op); +} + +/// Matches a constant holding a scalar/vector/tensor integer (splat) and +/// writes the integer value to bind_value. +inline detail::constant_int_op_binder +m_ConstantInt(IntegerAttr::ValueType *bind_value) { + return detail::constant_int_op_binder(bind_value); +} + +template +auto m_Op(Matchers... matchers) { + return detail::RecursivePatternMatcher(matchers...); +} + +namespace matchers { +inline auto m_Any() { return detail::AnyValueMatcher(); } +inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); } +} // namespace matchers + +} // end namespace mlir + +#endif // MLIR_MATCHERS_H diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h new file mode 100644 index 0000000000000000000000000000000000000000..babc51aad0d8b983932758c9a63c4a3791e78981 --- /dev/null +++ b/mlir/include/mlir/IR/Module.h @@ -0,0 +1,167 @@ +//===- Module.h - MLIR Module Class -----------------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Module is the top-level container for code in an MLIR program. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_MODULE_H +#define MLIR_IR_MODULE_H + +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +class ModuleTerminatorOp; + +//===----------------------------------------------------------------------===// +// Module Operation. +//===----------------------------------------------------------------------===// + +/// ModuleOp represents a module, or an operation containing one region with a +/// single block containing opaque operations. The region of a module is not +/// allowed to implicitly capture global values, and all external references +/// must use symbolic references via attributes(e.g. via a string name). +class ModuleOp + : public Op< + ModuleOp, OpTrait::ZeroOperands, OpTrait::ZeroResult, + OpTrait::IsIsolatedFromAbove, OpTrait::SymbolTable, + OpTrait::SingleBlockImplicitTerminator::Impl> { +public: + using Op::Op; + using Op::print; + + static StringRef getOperationName() { return "module"; } + + static void build(Builder *builder, OperationState &result, + Optional name = llvm::None); + + /// Construct a module from the given location with an optional name. + static ModuleOp create(Location loc, Optional name = llvm::None); + + /// Operation hooks. + static ParseResult parse(OpAsmParser &parser, OperationState &result); + void print(OpAsmPrinter &p); + LogicalResult verify(); + + /// Return body of this module. + Region &getBodyRegion(); + Block *getBody(); + + /// Return the name of this module if present. + Optional getName(); + + /// Print the this module in the custom top-level form. + void print(raw_ostream &os, OpPrintingFlags flags = llvm::None); + void dump(); + + //===--------------------------------------------------------------------===// + // Body Management. + //===--------------------------------------------------------------------===// + + /// Iteration over the operations in the module. 
+ using iterator = Block::iterator; + + iterator begin() { return getBody()->begin(); } + iterator end() { return getBody()->end(); } + Operation &front() { return *begin(); } + + /// This returns a range of operations of the given type 'T' held within the + /// module. + template iterator_range> getOps() { + return getBody()->getOps(); + } + + /// Insert the operation into the back of the body, before the terminator. + void push_back(Operation *op) { + insert(Block::iterator(getBody()->getTerminator()), op); + } + + /// Insert the operation at the given insertion point. Note: The operation is + /// never inserted after the terminator, even if the insertion point is end(). + void insert(Operation *insertPt, Operation *op) { + insert(Block::iterator(insertPt), op); + } + void insert(Block::iterator insertPt, Operation *op) { + auto *body = getBody(); + if (insertPt == body->end()) + insertPt = Block::iterator(body->getTerminator()); + body->getOperations().insert(insertPt, op); + } +}; + +/// The ModuleTerminatorOp is a special terminator operation for the body of a +/// ModuleOp, it has no semantic meaning beyond keeping the body of a ModuleOp +/// well-formed. +/// +/// This operation does _not_ have a custom syntax. However, ModuleOp will omit +/// the terminator in their custom syntax for brevity. +class ModuleTerminatorOp + : public Op::Impl, OpTrait::IsTerminator> { +public: + using Op::Op; + static StringRef getOperationName() { return "module_terminator"; } + static void build(Builder *, OperationState &) {} +}; + +/// This class acts as an owning reference to a module, and will automatically +/// destroy the held module if valid. +class OwningModuleRef { +public: + OwningModuleRef(std::nullptr_t = nullptr) {} + OwningModuleRef(ModuleOp module) : module(module) {} + OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {} + ~OwningModuleRef() { + if (module) + module.erase(); + } + + // Assign from another module reference. 
+ OwningModuleRef &operator=(OwningModuleRef &&other) { + if (module) + module.erase(); + module = other.release(); + return *this; + } + + /// Allow accessing the internal module. + ModuleOp get() const { return module; } + ModuleOp operator*() const { return module; } + ModuleOp *operator->() { return &module; } + explicit operator bool() const { return module; } + + /// Release the referenced module. + ModuleOp release() { + ModuleOp released; + std::swap(released, module); + return released; + } + +private: + ModuleOp module; +}; + +} // end namespace mlir + +namespace llvm { + +/// Allow stealing the low bits of ModuleOp. +template <> struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::ModuleOp I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::ModuleOp getFromVoidPointer(void *P) { + return mlir::ModuleOp::getFromOpaquePointer(P); + } + enum { NumLowBitsAvailable = 3 }; +}; + +} // end namespace llvm + +#endif // MLIR_IR_MODULE_H diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td new file mode 100644 index 0000000000000000000000000000000000000000..7e31c07575e1b35ab18ea2044420e7c52c7a054f --- /dev/null +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -0,0 +1,54 @@ +//===- OpAsmInterface.td - Asm Interfaces for opse ---------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains Interfaces for interacting with the AsmParser and +// AsmPrinter. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_OPASMINTERFACE +#define MLIR_OPASMINTERFACE + +include "mlir/IR/OpBase.td" + +/// Interface for hooking into the OpAsmPrinter and OpAsmParser. +def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> { + let description = [{ + This interface provides hooks to interact with the AsmPrinter and AsmParser + classes. + }]; + + let methods = [ + InterfaceMethod<[{ + Get a special name to use when printing the results of this operation. + The given callback is invoked with a specific result value that starts a + result "pack", and the name to give this result pack. To signal that a + result pack should use the default naming scheme, a None can be passed + in instead of the name. + + For example, if you have an operation that has four results and you want + to split these into three distinct groups you could do the following: + + ```c++ + setNameFn(getResult(0), "first_result"); + setNameFn(getResult(1), "middle_results"); + setNameFn(getResult(3), ""); // use the default numbering. + ``` + + This would print the operation as follows: + + ```mlir + %first_result, %middle_results:2, %0 = "my.op" ... + ``` + }], + "void", "getAsmResultNames", (ins "OpAsmSetValueNameFn":$setNameFn) + >, + ]; +} + +#endif // MLIR_OPASMINTERFACE diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td new file mode 100644 index 0000000000000000000000000000000000000000..c457d25fc51ae6313780598183e004d08101a630 --- /dev/null +++ b/mlir/include/mlir/IR/OpBase.td @@ -0,0 +1,1872 @@ +//===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the base operation definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef OP_BASE +#define OP_BASE + +//===----------------------------------------------------------------------===// +// Common utilities for defining TableGen mechanisms +//===----------------------------------------------------------------------===// + +// A workaround for the inability to define functions in Tablegen. +// +// The template parameter defines a string that can be extracted from an +// instance of this class by accessing the "result" member. Subclasses can take +// their own template parameters as function "arguments" and use them to +// populate result. +// For example, if it didn't already exist, a concat function could be defined +// like: +// +// class StrConcat strings> : +// StrFunc +// +// and then called like +// +// StrConcat<["a", "b", "c"]>.result +// +// to get the string "abc" +class StrFunc { + string result = r; +} + +// Concatenates a list of strings with a separator (default ", ") +class StrJoin strings, string sep = ", "> : + StrFunc; + +// Concatenates a list of integers into a string with a separator (default ", ") +class StrJoinInt integers, string sep = ", "> : + StrJoin(i)), sep>; + +//===----------------------------------------------------------------------===// +// Predicate definitions +//===----------------------------------------------------------------------===// + +// Base class for logical predicates. +// +// Predicates are used to compose constraints (see next section for details). +// There are two categories of predicates: +// +// 1. CPred: the primitive leaf predicate. +// 2. Compound predicate: a predicate composed from child predicates using +// predicate combiners ("conjunction", "disjunction", "negation" or +// "substitution"). 
+class Pred; + +// A logical predicate wrapping any C expression. +// +// This is the basis for composing more complex predicates. It is the "atom" +// predicate from the perspective of TableGen and the "interface" between +// TableGen and C++. What is inside is already C++ code, which will be treated +// as opaque strings with special placeholders to be substituted. +// +// ## Special placeholders +// +// Special placeholders can be used to refer to entities in the context where +// this predicate is used. They serve as "hooks" to the enclosing environment. +// The following special placeholders are supported in constraints for an op: +// +// * `$_builder` will be replaced by a mlir::Builder instance. +// * `$_op` will be replaced by the current operation. +// * `$_self` will be replaced with the entity this predicate is attached to. +// E.g., `BoolAttr` is an attribute constraint that wraps a +// `CPred<"$_self.isa()">` (see the following sections for details). +// Then for `F32:$attr`,`$_self` will be replaced by `$attr`. +// For type constraints, it's a little bit special since we want the +// constraints on each type definition reads naturally and we want to attach +// type constraints directly to an operand/result, $_self will be replaced +// by the operand/result's type. E.g., for `F32` in `F32:$operand`, its +// `$_self` will be expanded as `getOperand(...)->getType()`. +class CPred : Pred { + code predExpr = "(" # pred # ")"; +} + +// Kinds of predicate combiners. These must closely match the predicates +// implemented by the C++ backend (tblgen::PredCombinerKind). +class PredCombinerKind; +def PredCombinerAnd : PredCombinerKind; +def PredCombinerOr : PredCombinerKind; +def PredCombinerNot : PredCombinerKind; +def PredCombinerSubstLeaves : PredCombinerKind; +def PredCombinerConcat : PredCombinerKind; + +// A predicate that combines other predicates as defined by PredCombinerKind. +// Instantiated below. 
+class CombinedPred c> : Pred { + PredCombinerKind kind = k; + list children = c; +} + +// Predicate combiners + +// A predicate that holds if all of its children hold. Always holds for zero +// children. +class And children> : CombinedPred; + +// A predicate that holds if any of its children hold. Never holds for zero +// children. +class Or children> : CombinedPred; + +// A predicate that holds if its child does not. +class Neg : CombinedPred; + +// A predicate that substitutes "pat" with "repl" in predicate calls of the +// leaves of the predicate tree (i.e., not CombinedPred). +// +// This is plain string substitution without regular expressions or captures. +// New predicates with more complex logical can be introduced should the need +// arise. +class SubstLeaves + : CombinedPred { + string pattern = pat; + string replacement = repl; +} + +// A predicate that prepends `pre` and appends `suf` to the final predicate +// string composed from `child`. This is plain string concatenation and there +// will be no substitution happening for `pre` and `suf`. +class Concat : + CombinedPred { + string prefix = pre; + string suffix = suf; +} + +//===----------------------------------------------------------------------===// +// Constraint definitions +//===----------------------------------------------------------------------===// + +// TODO(b/130064155): Merge Constraints into Pred. + +// Base class for named constraints. +// +// An op's operands/attributes/results can have various requirements, e.g., +// having certain types, having values inside a certain range, and so on. +// Besides, for a graph rewrite rule, the source pattern used to match against +// the existing graph has conditions, like the op's operand must be of a more +// constrained subtype, the attribute must have a certain value, and so on. +// +// These requirements and conditions are modeled using this class. 
Records of +// this class are used to generate verification code in op verifier, and +// matching code in pattern matcher. +// +// Constraints are predicates with descriptive names, to facilitate inspection, +// provide nice error messages, etc. +class Constraint { + // The predicates that this constraint requires. + Pred predicate = pred; + // User-readable description used in error reporting messages. If empty, a + // generic message will be used. + string description = desc; +} + +// Subclasses used to differentiate different constraint kinds. These are used +// as markers for the TableGen backend to handle different constraint kinds +// differently if needed. Constraints not deriving from the following subclasses +// are considered as uncategorized constraints. + +// Subclass for constraints on a type. +class TypeConstraint : + Constraint; + +// Subclass for constraints on an attribute. +class AttrConstraint : + Constraint; + +// Subclass for constraints on a region. +class RegionConstraint : + Constraint; + +// How to use these constraint categories: +// +// * Use TypeConstraint to specify +// * Constraints on an op's operand/result definition +// * Further constraints to match an op's operand/result in source pattern +// +// * Use Attr (a subclass for AttrConstraint) for +// * Constraints on an op's attribute definition +// * Use AttrConstraint to specify +// * Further constraints to match an op's attribute in source pattern +// +// * Use uncategorized constraint to specify +// * Multi-entity constraints in rewrite rules + +//===----------------------------------------------------------------------===// +// Common predicates +//===----------------------------------------------------------------------===// + +// Whether a type is a VectorType. +def IsVectorTypePred : CPred<"$_self.isa()">; + +// Whether a type is a TensorType. +def IsTensorTypePred : CPred<"$_self.isa()">; + +// Whether a type is a MemRefType. 
+def IsMemRefTypePred : CPred<"$_self.isa()">; + +// Whether a type is an IsUnrankedMemRefType +def IsUnrankedMemRefTypePred : CPred<"$_self.isa()">; + +// Whether a type is a ShapedType. +def IsShapedTypePred : CPred<"$_self.isa()">; + +// For a ShapedType, verify that it has a static shape. +def HasStaticShapePred : CPred<"$_self.cast().hasStaticShape()">; + +// Whether a type is a TupleType. +def IsTupleTypePred : CPred<"$_self.isa()">; + +//===----------------------------------------------------------------------===// +// Dialect definitions +//===----------------------------------------------------------------------===// + +class Dialect { + // The name of the dialect. + string name = ?; + + // Short summary of the dialect. + string summary = ?; + + // The description of the dialect. + string description = ?; + + // The C++ namespace that ops of this dialect should be placed into. + // + // By default, uses the name of the dialect as the only namespace. To avoid + // placing in any namespace, use "". To specify nested namespaces, use "::" + // as the delimiter, e.g., given "A::B", ops will be placed in + // `namespace A { namespace B { } }`. + // + // Note that this works in conjunction with dialect C++ code. Depending on how + // the generated files are included into the dialect, you may want to specify + // a full namespace path or a partial one. + string cppNamespace = name; +} + +//===----------------------------------------------------------------------===// +// Type definitions +//===----------------------------------------------------------------------===// + +// A type, carries type constraints. +class Type : + TypeConstraint { + string typeDescription = ""; +} + +// Allows providing an alternative name and description to an existing type def. +class TypeAlias : + Type { + let typeDescription = t.typeDescription; +} + +// A type of a specific dialect. +class DialectType : + Type { + Dialect dialect = d; +} + +// A variadic type constraint. 
It expands to zero or more of the base type. This +// class is used for supporting variadic operands/results. An op can declare no +// more than one variadic operand/result, and that operand/result must be the +// last one in the operand/result list. +class Variadic : TypeConstraint { + Type baseType = type; +} + +// A type that can be constructed using MLIR::Builder. +// Note that this does not "inherit" from Type because it would require +// duplicating Type subclasses for buildable and non-buildable cases to avoid +// diamond "inheritance". +// TODO(zinenko): we may extend this to a more general 'Buildable' trait, +// making some Types and some Attrs buildable. +class BuildableType { + // The builder call to invoke (if specified) to construct the BuildableType. + // Format: this will be affixed to the builder. + code builderCall = builder; +} + +// Any type at all. +def AnyType : Type, "any type">; + +// None type +def NoneType : Type()">, "none type">; + +// Any type from the given list +class AnyTypeOf allowedTypes, string description = ""> : Type< + // Satisfy any of the allowed type's condition + Or, + !if(!eq(description, ""), + StrJoin.result, + description)>; + +// Integer types. +// Any integer type irrespective of its width. +def AnyInteger : Type()">, "integer">; + +// Index type. +def Index : Type()">, "index">; + +// Integer type of a specific width. +class I + : Type, + width # "-bit integer">, + BuildableType<"getIntegerType(" # width # ")"> { + int bitwidth = width; +} + +class IntOfWidths widths> : + AnyTypeOf), + StrJoinInt.result # "-bit integer">; + +def I1 : I<1>; +def I8 : I<8>; +def I16 : I<16>; +def I32 : I<32>; +def I64 : I<64>; + +// Floating point types. + +// Any float type irrespective of its width. +def AnyFloat : Type()">, "floating-point">; + +// Float type of a specific width. 
+class F + : Type, + width # "-bit float">, + BuildableType<"getF" # width # "Type()"> { + int bitwidth = width; +} + +class FloatOfWidths widths> : + AnyTypeOf), + StrJoinInt.result # "-bit float">; + +def F16 : F<16>; +def F32 : F<32>; +def F64 : F<64>; + +def BF16 : Type, "bfloat16 type">, + BuildableType<"getBF16Type()">; + +class Complex + : Type()">, + SubstLeaves<"$_self", "$_self.cast().getElementType()", + type.predicate>]>, + "complex type with " # type.description # " elements"> { + Type elementType = type; +} + +def AnyComplex : Type()">, "complex-type">; + +class OpaqueType + : Type, + description>; + +// Function Type + +// Any function type. +def FunctionType : Type()">, "function type">; + +// A container type is a type that has another type embedded within it. +class ContainerType : + // First, check the container predicate. Then, substitute the extracted + // element into the element type checker. + Type(elementTypeCall), + etype.predicate>]>, + descr # " of " # etype.description # " values"> { + // The type of elements in the container. + Type elementType = etype; + + // Call to retrieve. + code getElementTypeCall = elementTypeCall; +} + +class ShapedContainerType allowedTypes, Pred containerPred, string descr> : + ContainerType, containerPred, + "$_self.cast().getElementType()", descr>; + +// Whether a shaped type is ranked. +def HasRankPred : CPred<"$_self.cast().hasRank()">; + +// Whether a shaped type has one of the specified ranks. +class HasAnyRankOfPred ranks> : And<[ + HasRankPred, + Or().getRank() == " # rank>)>]>; + +// Vector types. 
+ +class VectorOf allowedTypes> : + ShapedContainerType; + +// Whether the number of elements of a vector is from the given +// `allowedLengths` list +class IsVectorOfLengthPred allowedLengths> : + And<[IsVectorTypePred, + Or().getNumElements() + == }] + # allowedlength>)>]>; + +// Any vector where the number of elements is from the given +// `allowedLengths` list +class VectorOfLength allowedLengths> : Type< + IsVectorOfLengthPred, + " of length " # StrJoinInt.result>; + + +// Any vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` +// list +class VectorOfLengthAndType allowedLengths, + list allowedTypes> : Type< + And<[VectorOf.predicate, + VectorOfLength.predicate]>, + VectorOf.description # + VectorOfLength.description>; + +def AnyVector : VectorOf<[AnyType]>; + +// Tensor types. + +// Any tensor type whose element type is from the given `allowedTypes` list +class TensorOf allowedTypes> : + ShapedContainerType; + +def AnyTensor : TensorOf<[AnyType]>; + +def AnyRankedTensor : + ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>, + "ranked tensor">; + +// TODO(b/130064155) Have an easy way to add another constraint to a type. +class StaticShapeTensorOf allowedTypes> + : Type.predicate, HasStaticShapePred]>, + "statically shaped " # TensorOf.description>; + +def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; + +def I1Tensor : TensorOf<[I1]>; +def I8Tensor : TensorOf<[I8]>; +def I16Tensor : TensorOf<[I16]>; +def I32Tensor : TensorOf<[I32]>; +def I64Tensor : TensorOf<[I64]>; + +def BF16Tensor : TensorOf<[BF16]>; +def F16Tensor : TensorOf<[F16]>; +def F32Tensor : TensorOf<[F32]>; +def F64Tensor : TensorOf<[F64]>; + +// Ranked tensor type with one of the specified types and ranks. 
+class TensorRankOf allowedTypes, list ranks> : + Type.predicate, HasAnyRankOfPred]>, + StrJoin.result # " " # + TensorOf.description>; + +class 0DTensorOf allowedTypes> : TensorRankOf; +class 1DTensorOf allowedTypes> : TensorRankOf; +class 2DTensorOf allowedTypes> : TensorRankOf; +class 3DTensorOf allowedTypes> : TensorRankOf; +class 4DTensorOf allowedTypes> : TensorRankOf; + +// Unranked Memref type +def AnyUnrankedMemRef : + ShapedContainerType<[AnyType], + IsUnrankedMemRefTypePred, "unranked.memref">; +// Memref type. + +// Memrefs are blocks of data with fixed type and rank. +class MemRefOf allowedTypes> : + ShapedContainerType; + +def AnyMemRef : MemRefOf<[AnyType]>; + +def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>; + +// Memref declarations handle any memref, independent of rank, size, (static or +// dynamic), layout, or memory space. +def I1MemRef : MemRefOf<[I1]>; +def I8MemRef : MemRefOf<[I8]>; +def I16MemRef : MemRefOf<[I16]>; +def I32MemRef : MemRefOf<[I32]>; +def I64MemRef : MemRefOf<[I64]>; + +def BF16MemRef : MemRefOf<[BF16]>; +def F16MemRef : MemRefOf<[F16]>; +def F32MemRef : MemRefOf<[F32]>; +def F64MemRef : MemRefOf<[F64]>; + +// TODO(b/130064155) Have an easy way to add another constraint to a type. +class MemRefRankOf allowedTypes, list ranks> : + Type.predicate, HasAnyRankOfPred]>, + StrJoin.result # " " # + MemRefOf.description>; + +class StaticShapeMemRefOf allowedTypes> + : Type.predicate, HasStaticShapePred]>, + "statically shaped " # MemRefOf.description>; + +def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>; + +// For a MemRefType, verify that it has strides. 
+def HasStridesPred : CPred<[{ isStrided($_self.cast()) }]>; + +class StridedMemRefOf allowedTypes> + : Type.predicate, HasStridesPred]>, + "strided " # MemRefOf.description>; + +def AnyStridedMemRef : StridedMemRefOf<[AnyType]>; + +class AnyStridedMemRefOfRank : + Type.predicate]>, + AnyStridedMemRef.description # " of rank " # rank>; + +// This represents a generic tuple without any constraints on element type. +def AnyTuple : Type; + +// A container type that has other types embedded in it, but (unlike +// ContainerType) can hold elements with a mix of types. Requires a call that +// produces a list of all elements' types. +class MixedContainerType : + Type< + And<[ + containerPred, + Concat< + "llvm::all_of(" # elementTypesCall # ", [](Type t) { return ", + SubstLeaves<"$_self", "t", etype.predicate>, + "; })" + > + ]>, + descr # " with any combination of " # etype.description # " values"> { + // The type of elements in the container. + Type elementType = etype; + + // Call to retrieve. + code getElementTypesCall = elementTypesCall; +} + +// A Tuple that holds a mix of elements of the allowed types. +class TupleOf allowedTypes> + : MixedContainerType, IsTupleTypePred, + "$_self.cast().getTypes()", "tuple">; + +// A Tuple with arbitrary nesting, where all elements are a mix of the allowed +// types. +class NestedTupleOf allowedTypes> : + MixedContainerType, IsTupleTypePred, + "getFlattenedTypes($_self.cast())", + "nested tuple">; + +//===----------------------------------------------------------------------===// +// Common type constraints +//===----------------------------------------------------------------------===// + +// Type constraint for bool-like types: bools, vectors of bools, tensors of +// bools. +def BoolLike : TypeConstraint.predicate, + TensorOf<[I1]>.predicate]>, + "bool-like">; + +// Type constraint for integer-like types: integers, indices, vectors of +// integers, tensors of integers. 
+def IntegerLike : TypeConstraint.predicate, TensorOf<[AnyInteger]>.predicate]>, + "integer-like">; + +// Type constraint for float-like types: floats, vectors or tensors thereof. +def FloatLike : TypeConstraint.predicate, TensorOf<[AnyFloat]>.predicate]>, + "floating-point-like">; + + +//===----------------------------------------------------------------------===// +// Attribute definitions +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Base attribute definition + +// Base class for all attributes. +class Attr : + AttrConstraint { + code storageType = ?; // The backing mlir::Attribute type + code returnType = ?; // The underlying C++ value type + + // The call expression to convert from the storage type to the return + // type. For example, an enum can be stored as an int but returned as an + // enum class. + // + // Format: $_self will be expanded to the attribute. + // + // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will + // expand to `getAttrOfType("val").getValue().getSExtValue()`. + code convertFromStorage = "$_self.getValue()"; + + // The call expression to build an attribute from a constant value. + // + // Format: $0 will be expanded to the constant value of the attribute. + // + // For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will + // expand to `builder.getStringAttr("foo")`. + string constBuilderCall = ?; + + // Default value for attribute. + // Requires a constBuilderCall defined. + string defaultValue = ?; + + // Whether the attribute is optional. Typically requires a custom + // convertFromStorage method to handle the case where the attribute is + // not present. + bit isOptional = 0; + + // What is the base-level Attr instantiation that this Attr is built upon. + // Unset means this is a base-level Attr. 
+ // + // This field is used by attribute wrapper classes (DefaultValuedAttr, + // OptionalAttr, etc.) to retrieve the base-level attribute definition. + // This can be used for getting its name; otherwise, we will see + // "anonymous_" as the attribute def name because of template + // instantiation. + // TOOD(b/132458159): deduplicate the fields in attribute wrapper classes. + Attr baseAttr = ?; +} + +//===----------------------------------------------------------------------===// +// Attribute modifier definition + +// Decorates an attribute to have an (unvalidated) default value if not present. +class DefaultValuedAttr : + Attr { + // Construct this attribute with the input attribute and change only + // the default value. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = attr.returnType; + let convertFromStorage = attr.convertFromStorage; + let constBuilderCall = attr.constBuilderCall; + let defaultValue = val; + + let baseAttr = attr; +} + +// Decorates an attribute as optional. The return type of the generated +// attribute accessor method will be Optional<>. +class OptionalAttr : Attr { + // Rewrite the attribute to be optional. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = "Optional<" # attr.returnType #">"; + let convertFromStorage = "$_self ? " # returnType # "(" # + attr.convertFromStorage # ") : (llvm::None)"; + let isOptional = 1; + + let baseAttr = attr; +} + +//===----------------------------------------------------------------------===// +// Primitive attribute kinds + +// A generic attribute that must be constructed around a specific type +// `attrValType`. Backed by MLIR attribute kind `attrKind`. +class TypedAttrBase : + Attr { + let constBuilderCall = "$_builder.get" # attrKind # "($_builder." # + attrValType.builderCall # ", $0)"; + let storageType = attrKind; +} + +// Any attribute. 
+def AnyAttr : Attr, "any attribute"> { + let storageType = "Attribute"; + let returnType = "Attribute"; + let convertFromStorage = "$_self"; + let constBuilderCall = "$0"; +} + +def BoolAttr : Attr()">, "bool attribute"> { + let storageType = [{ BoolAttr }]; + let returnType = [{ bool }]; + let constBuilderCall = "$_builder.getBoolAttr($0)"; +} + +// Base class for integer attributes of fixed width. +class IntegerAttrBase : + TypedAttrBase< + attrValType, "IntegerAttr", + And<[CPred<"$_self.isa()">, + CPred<"$_self.cast().getType()." + "isInteger(" # attrValType.bitwidth # ")">]>, + descr> { + let returnType = [{ APInt }]; +} + +def APIntAttr : Attr()">, + "arbitrary integer attribute"> { + let storageType = [{ IntegerAttr }]; + let returnType = [{ APInt }]; +} + +def I1Attr : IntegerAttrBase; +def I8Attr : IntegerAttrBase; +def I16Attr : IntegerAttrBase; +def I32Attr : IntegerAttrBase; +def I64Attr : IntegerAttrBase; + +class NonNegativeIntAttrBase : + TypedAttrBase< + attrValType, "IntegerAttr", + And<[IntegerAttrBase.predicate, + CPred<"!$_self.cast().getValue().isNegative()">]>, + descr> { + let returnType = [{ APInt }]; +} + +def NonNegativeI32Attr : NonNegativeIntAttrBase< + I32, "non-negative 32-bit integer attribute">; +def NonNegativeI64Attr : NonNegativeIntAttrBase< + I64, "non-negative 64-bit integer attribute">; + +class PositiveIntAttrBase : + TypedAttrBase< + attrValType, "IntegerAttr", + And<[IntegerAttrBase.predicate, + CPred<"$_self.cast().getValue()" + ".isStrictlyPositive()">]>, + descr> { + let returnType = [{ APInt }]; +} + +def PositiveI32Attr : PositiveIntAttrBase< + I32, "positive 32-bit integer attribute">; +def PositiveI64Attr : PositiveIntAttrBase< + I64, "positive 64-bit integer attribute">; + +// Base class for float attributes of fixed width. 
+class FloatAttrBase : + TypedAttrBase()">, + CPred<"$_self.cast().getType().isF" # + attrValType.bitwidth # "()">]>, + descr> { + let returnType = [{ APFloat }]; +} + +def F32Attr : FloatAttrBase; +def F64Attr : FloatAttrBase; + +// An attribute backed by a string type. +class StringBasedAttr : Attr { + let constBuilderCall = "$_builder.getStringAttr(\"$0\")"; + let storageType = [{ StringAttr }]; + let returnType = [{ StringRef }]; +} + +def StrAttr : StringBasedAttr()">, + "string attribute">; + +// Base class for attributes containing types. Example: +// def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute"> +// defines a type attribute containing an integer type. +class TypeAttrBase : + Attr()">, + CPred<"$_self.cast().getValue().isa<" # retType # ">()">]>, + description> { + let storageType = [{ TypeAttr }]; + let returnType = retType; + let convertFromStorage = "$_self.getValue().cast<" # retType # ">()"; +} + +def TypeAttr : TypeAttrBase<"Type", "any type attribute">; + +// The mere presence of unit attributes has a meaning. Therefore, unit +// attributes are always treated as optional and accessors to them return +// "true" if the attribute is present and "false" otherwise. +def UnitAttr : Attr()">, "unit attribute"> { + let storageType = [{ UnitAttr }]; + let constBuilderCall = "$_builder.getUnitAttr()"; + let convertFromStorage = "$_self != nullptr"; + let returnType = "bool"; + let isOptional = 1; +} + +//===----------------------------------------------------------------------===// +// Enum attribute kinds + +// Additional information for an enum attribute case. +class EnumAttrCaseInfo { + // The C++ enumerant symbol + string symbol = sym; + + // The C++ enumerant value + // If less than zero, there will be no explicit discriminator values assigned + // to enumerators in the generated enum class. + int value = val; +} + +// An enum attribute case stored with StringAttr. 
+class StrEnumAttrCase : + EnumAttrCaseInfo, + StringBasedAttr< + CPred<"$_self.cast().getValue() == \"" # sym # "\"">, + "case " # sym>; + +// An enum attribute case stored with IntegerAttr. +class IntEnumAttrCaseBase : + EnumAttrCaseInfo, + IntegerAttrBase { + let predicate = + CPred<"$_self.cast().getInt() == " # val>; +} + +class I32EnumAttrCase : IntEnumAttrCaseBase; +class I64EnumAttrCase : IntEnumAttrCaseBase; + +// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the +// ordinal number of the bit that is set. It is the 32-bit integer with only +// one bit set. +class BitEnumAttrCase : + EnumAttrCaseInfo, + IntegerAttrBase { + let predicate = CPred< + "$_self.cast().getValue().getZExtValue() & " # val # "u">; +} + +// Additional information for an enum attribute. +class EnumAttrInfo cases> { + // The C++ enum class name + string className = name; + + // List of all accepted cases + list enumerants = cases; + + // The following fields are only used by the EnumsGen backend to generate + // an enum class definition and conversion utility functions. + + // The underlying type for the C++ enum class. An empty string mean the + // underlying type is not explicitly specified. + string underlyingType = ""; + + // The C++ namespaces that the enum class definition and utility functions + // should be placed into. + // + // Normally you want to place the full namespace path here. If it is nested, + // use "::" as the delimiter, e.g., given "A::B", generated code will be + // placed in `namespace A { namespace B { ... } }`. To avoid placing in any + // namespace, use "". + // TODO(b/134741431): use dialect to provide the namespace. + string cppNamespace = ""; + + // The name of the utility function that converts a value of the underlying + // type to the corresponding symbol. 
It will have the following signature: + // + // ```c++ + // llvm::Optional<> (); + // ``` + string underlyingToSymbolFnName = "symbolize" # name; + + // The name of the utility function that converts a string to the + // corresponding symbol. It will have the following signature: + // + // ```c++ + // llvm::Optional<> (llvm::StringRef); + // ``` + string stringToSymbolFnName = "symbolize" # name; + + // The name of the utility function that converts a symbol to the + // corresponding string. It will have the following signature: + // + // ```c++ + // (); + // ``` + string symbolToStringFnName = "stringify" # name; + string symbolToStringFnRetType = "llvm::StringRef"; + + // The name of the utility function that returns the max enum value used + // within the enum class. It will have the following signature: + // + // ```c++ + // static constexpr unsigned (); + // ``` + string maxEnumValFnName = "getMaxEnumValFor" # name; +} + +// An enum attribute backed by StringAttr. +// +// Op attributes of this kind are stored as StringAttr. Extra verification will +// be generated on the string though: only the symbols of the allowed cases are +// permitted as the string value. +class StrEnumAttr cases> : + EnumAttrInfo, + StringBasedAttr< + And<[StrAttr.predicate, Or]>, + !if(!empty(description), "allowed string cases: " # + StrJoin.result, + description)>; + +// An enum attribute backed by IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer though: only the values of the allowed cases are +// permitted as the integer value. 
+class IntEnumAttr cases> : + EnumAttrInfo, + IntegerAttrBase.result, description)> { + let predicate = And<[ + IntegerAttrBase.predicate, + Or]>; +} + +class I32EnumAttr cases> : + IntEnumAttr { + let returnType = cppNamespace # "::" # name; + let underlyingType = "uint32_t"; + let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast($0))"; +} +class I64EnumAttr cases> : + IntEnumAttr { + let returnType = cppNamespace # "::" # name; + let underlyingType = "uint64_t"; + let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getI64IntegerAttr(static_cast($0))"; +} + +// A bit enum stored with 32-bit IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer to make sure only allowed bit are set. Besides, +// helper methods are generated to parse a string separated with a specified +// delimiter to a symbol and vice versa. +class BitEnumAttr cases> : + EnumAttrInfo, IntegerAttrBase { + let predicate = And<[ + IntegerAttrBase.predicate, + // Make sure we don't have unknown bit set. + CPred<"!($_self.cast().getValue().getZExtValue() & (~(" # + StrJoin.result # + ")))"> + ]>; + + let returnType = cppNamespace # "::" # name; + let underlyingType = "uint32_t"; + let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast($0))"; + + // We need to return a string because we may concatenate symbols for multiple + // bits together. + let symbolToStringFnRetType = "std::string"; + + // The delimiter used to separate bit enum cases in strings. 
+ string separator = "|"; +} + +//===----------------------------------------------------------------------===// +// Composite attribute kinds + +class DictionaryAttrBase : Attr()">, + "dictionary of named attribute values"> { + let storageType = [{ DictionaryAttr }]; + let returnType = [{ DictionaryAttr }]; + let convertFromStorage = "$_self"; +} + +def DictionaryAttr : DictionaryAttrBase; + +class ElementsAttrBase : + Attr { + let storageType = [{ ElementsAttr }]; + let returnType = [{ ElementsAttr }]; + let convertFromStorage = "$_self"; +} + +def ElementsAttr : ElementsAttrBase()">, + "constant vector/tensor attribute">; + +class IntElementsAttr : ElementsAttrBase< + CPred<"$_self.isa() &&" + "$_self.cast().getType()." + "getElementType().isInteger(" # width # ")">, + width # "-bit integer elements attribute"> { + + let storageType = [{ DenseIntElementsAttr }]; + let returnType = [{ DenseIntElementsAttr }]; + + // Note that this is only constructing scalar elements attribute. + let constBuilderCall = "DenseElementsAttr::get(" + "RankedTensorType::get({}, $_builder.getIntegerType(" # width # ")), " + "llvm::makeArrayRef($0)).cast()"; + let convertFromStorage = "$_self"; +} + +def I32ElementsAttr : IntElementsAttr<32>; +def I64ElementsAttr : IntElementsAttr<64>; + +class FloatElementsAttr : ElementsAttrBase< + CPred<"$_self.isa() &&" + "$_self.cast().getType()." + "getElementType().isF" # width # "()">, + width # "-bit float elements attribute"> { + + let storageType = [{ DenseElementsAttr }]; + let returnType = [{ DenseElementsAttr }]; + + // Note that this is only constructing scalar elements attribute. + let constBuilderCall = "DenseElementsAttr::get(" + "RankedTensorType::get({}, $_builder.getF" # width # "Type())," + "llvm::makeArrayRef($0))"; + let convertFromStorage = "$_self"; +} + +def F64ElementsAttr : FloatElementsAttr<64>; + +// A `width`-bit floating point elements attribute. The attribute should be +// ranked and has a shape as specified in `dims`. 
+class RankedFloatElementsAttr dims> : ElementsAttrBase< + CPred<"$_self.isa() &&" + "$_self.cast().getType()." + "getElementType().isF" # width # "() && " + // Check that this is ranked and has the specified shape. + "$_self.cast().getType().hasRank() && " + "$_self.cast().getType().getShape() == " + "llvm::ArrayRef({" # StrJoinInt.result # "})">, + width # "-bit float elements attribute of shape [" # + StrJoinInt.result # "]"> { + + let storageType = [{ DenseFPElementsAttr }]; + let returnType = [{ DenseFPElementsAttr }]; + + let constBuilderCall = "DenseElementsAttr::get(" + "RankedTensorType::get({" # StrJoinInt.result # + "}, $_builder.getF" # width # "Type()), " + "llvm::makeArrayRef($0)).cast()"; + let convertFromStorage = "$_self"; +} + +class RankedF32ElementsAttr dims> : RankedFloatElementsAttr<32, dims>; +class RankedF64ElementsAttr dims> : RankedFloatElementsAttr<64, dims>; + +// Base class for array attributes. +class ArrayAttrBase : + Attr { + let storageType = [{ ArrayAttr }]; + let returnType = [{ ArrayAttr }]; + let convertFromStorage = "$_self"; +} + +def ArrayAttr : ArrayAttrBase()">, + "array attribute">; + +// Base class for array attributes whose elements are of the same kind. +// `element` specifies the element attribute kind stored in this array. 
+class TypedArrayAttrBase: ArrayAttrBase< + And<[ + // Guarantee this is an ArrayAttr first + CPred<"$_self.isa()">, + // Guarantee all elements satisfy the constraints from `element` + Concat<"llvm::all_of($_self.cast(), " + "[](Attribute attr) { return ", + SubstLeaves<"$_self", "attr", element.predicate>, + "; })">]>, + description> { + let constBuilderCall = "$_builder.getArrayAttr($0)"; +} + +def I32ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getI32ArrayAttr($0)"; +} +def I64ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; +} +def F32ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getF32ArrayAttr($0)"; +} +def F64ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getF64ArrayAttr($0)"; +} +def StrArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getStrArrayAttr($0)"; +} +def TypeArrayAttr : TypedArrayAttrBase { + let constBuilderCall = ?; +} + +// Attribute information for an Attribute field within a StructAttr. +class StructFieldAttr { + // Name of this field in the StructAttr. + string name = thisName; + + // Attribute type wrapped by the struct attr. + Attr type = thisType; +} + +// Structured attribute that wraps a DictionaryAttr and provides both a +// validation method and set of accessors for a fixed set of fields. This is +// useful when representing data that would normally be in a structure. +class StructAttr attributes> : DictionaryAttrBase { + // Name for this StructAttr. + string className = name; + + // Return type should match the name of the structure. + let returnType = name; + + // Storage type should match the name of the structure. + let storageType = name; + + // The dialect this StructAttr belongs to. + Dialect structDialect = dialect; + + // List of fields that the StructAttr contains. + list fields = attributes; +} + +// Attributes containing symbol references. 
+def SymbolRefAttr : Attr()">, + "symbol reference attribute"> { + let storageType = [{ SymbolRefAttr }]; + let returnType = [{ SymbolRefAttr }]; + let constBuilderCall = "$_builder.getSymbolRefAttr($0)"; + let convertFromStorage = "$_self"; +} +def FlatSymbolRefAttr : Attr()">, + "flat symbol reference attribute"> { + let storageType = [{ FlatSymbolRefAttr }]; + let returnType = [{ StringRef }]; + let constBuilderCall = "$_builder.getSymbolRefAttr($0)"; + let convertFromStorage = "$_self.getValue()"; +} + +def SymbolRefArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; +} + +//===----------------------------------------------------------------------===// +// Derive attribute kinds + +// DerivedAttr are attributes whose value is computed from properties +// of the operation. They do not require additional storage and are +// materialized as needed. +class DerivedAttr : Attr, "derived attribute"> { + let returnType = ret; + code body = b; +} + +// Derived attribute that returns a mlir::Type. +class DerivedTypeAttr : DerivedAttr<"Type", body>; + +//===----------------------------------------------------------------------===// +// Constant attribute kinds + +// Represents a constant attribute of specific Attr type. A constant +// attribute can be specified only of attributes that have a constant +// builder call defined. The constant value is specified as a string. +// +// If used as a constraint, it generates a matcher on a constant attribute by +// using the constant value builder of the attribute and the value. 
+class ConstantAttr : AttrConstraint< + CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>, + "constant attribute " # val> { + Attr attr = attribute; + string value = val; +} + +class ConstF32Attr : ConstantAttr; +def ConstBoolAttrFalse : ConstantAttr; +def ConstBoolAttrTrue : ConstantAttr; +def ConstUnitAttr : ConstantAttr; + +//===----------------------------------------------------------------------===// +// Common attribute constraints +//===----------------------------------------------------------------------===// + +// A general mechanism to further confine the given `attr` with all the +// `constraints`. This allows to compose complex constraints out of a series +// of more primitive ones. +class Confined constraints> : Attr< + And, + !foldl(/*init*/attr.description, /*list*/constraints, + prev, cur, prev # " " # cur.description)> { + let storageType = attr.storageType; + let returnType = attr.returnType; + let convertFromStorage = attr.convertFromStorage; + let constBuilderCall = attr.constBuilderCall; + let defaultValue = attr.defaultValue; + let isOptional = attr.isOptional; + + let baseAttr = attr; +} + +// An AttrConstraint that holds if all attr constraints specified in +// 'constraints' hold. 
+class AllAttrConstraintsOf constraints> : AttrConstraint< + And, + !foldl(/*init*/!head(constraints).description, /*list*/!tail(constraints), + prev, cur, prev # " and " # cur.description)> { +} + +class IntMinValue : AttrConstraint< + CPred<"$_self.cast().getInt() >= " # n>, + "whose minimum value is " # n>; + +class IntMaxValue : AttrConstraint< + CPred<"$_self.cast().getInt() <= " # n>, + "whose maximum value is " # n>; + +class ArrayMinCount : AttrConstraint< + CPred<"$_self.cast().size() >= " # n>, + "with at least " # n # " elements">; + +class ArrayCount : AttrConstraint< + CPred<"$_self.cast().size() == " #n>, + "with exactly " # n # " elements">; + +class IntArrayNthElemEq : AttrConstraint< + And<[ + CPred<"$_self.cast().size() > " # index>, + CPred<"$_self.cast().getValue()[" # index # "]" + ".cast().getInt() == " # value> + ]>, + "whose " # index # "-th element must be " # value>; + +class IntArrayNthElemMinValue : AttrConstraint< + And<[ + CPred<"$_self.cast().size() > " # index>, + CPred<"$_self.cast().getValue()[" # index # "]" + ".cast().getInt() >= " # min> + ]>, + "whose " # index # "-th element must be at least " # min>; + +def IsNullAttr : AttrConstraint< + CPred<"!$_self">, "empty attribute (for optional attributes)">; + +// An attribute constraint on FlatSymbolRefAttr that requires that the +// reference point to an op of `opClass` within the closest parent with a symbol +// table. +// TODO(riverriddle) Add support for nested symbol references. +class ReferToOp : AttrConstraint< + CPred<"isa_and_nonnull<" # opClass # ">(" + "::mlir::SymbolTable::lookupNearestSymbolFrom(" + "&$_op, $_self.cast().getValue()))">, + "referencing to a '" # opClass # "' symbol">; + +//===----------------------------------------------------------------------===// +// Region definitions +//===----------------------------------------------------------------------===// + +class Region : + RegionConstraint; + +// Any region. 
+def AnyRegion : Region, "any region">; + +// A region with the given number of blocks. +class SizedRegion : Region< + CPred<"$_self.getBlocks().size() == " # numBlocks>, + "region with " # numBlocks # " blocks">; + +//===----------------------------------------------------------------------===// +// OpTrait definitions +//===----------------------------------------------------------------------===// + +// OpTrait represents a trait regarding an op. +class OpTrait; + +// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The +// purpose to wrap around C++ symbol string with this class is to make +// traits specified for ops in TableGen less alien and more integrated. +class NativeOpTrait : OpTrait { + string trait = "OpTrait::" # prop; +} + +// ParamNativeOpTrait corresponds to the template-parameterized traits in the +// C++ implementation. MLIR uses nested class templates to implement such +// traits leading to constructs of the form "TraitName::Impl". Use +// the value in `prop` as the trait name and the value in `params` as +// parameters to construct the native trait class name. +class ParamNativeOpTrait + : NativeOpTrait::Impl">; + +// GenInternalOpTrait is an op trait that does not have direct C++ mapping but +// affects op definition generator internals, like how op builders and +// operand/attribute/result getters are generated. +class GenInternalOpTrait : OpTrait { + string trait = "OpTrait::" # prop; +} + +// PredOpTrait is an op trait implemented by way of a predicate on the op. +class PredOpTrait : OpTrait { + string description = descr; + Pred predicate = pred; +} + +// Op supports operand broadcast behavior. +def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">; +// X op Y == Y op X +def Commutative : NativeOpTrait<"IsCommutative">; +// Op behaves like a function. +def FunctionLike : NativeOpTrait<"FunctionLike">; +// Op is isolated from above. 
+def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">; +// Op results are float or vectors/tensors thereof. +def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">; +// Op has no side effect. +def NoSideEffect : NativeOpTrait<"HasNoSideEffect">; +// Op has the same operand type. +def SameTypeOperands : NativeOpTrait<"SameTypeOperands">; +// Op has same shape for all operands. +def SameOperandsShape : NativeOpTrait<"SameOperandsShape">; +// Op has same operand and result shape. +def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">; +// Op has the same operand and result type. +def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">; +// Op has the same element type (or type itself, if scalar) for all operands. +def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">; +// Op has the same operand and result element type (or type itself, if scalar). +def SameOperandsAndResultElementType : + NativeOpTrait<"SameOperandsAndResultElementType">; +// Op is a symbol. +def Symbol : NativeOpTrait<"Symbol">; +// Op defines a symbol table. +def SymbolTable : NativeOpTrait<"SymbolTable">; +// Op is a terminator. +def Terminator : NativeOpTrait<"IsTerminator">; + +// Op's regions have a single block with the specified terminator. +class SingleBlockImplicitTerminator + : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>; + +// Op's parent operation is the provided one. +class HasParent + : ParamNativeOpTrait<"HasParent", op>; + +// Op result type is derived from the first attribute. If the attribute is an +// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the +// attribute content is used. +def FirstAttrDerivedResultType : + GenInternalOpTrait<"FirstAttrDerivedResultType">; + +// TODO(antiagainst): Turn the following into normal traits and generate +// verification for them. + +// All variadic operands of the op have the same number of values. 
+// A variadic operand contains an array of values whose array size is only +// known at runtime. This trait requires all variadic operands of an op +// to have the same array size. +def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">; +// All variadic results of the op have the same number of values. +// A variadic result contains an array of values whose array size is only +// known at runtime. This trait requires all variadic results of an op +// to have the same array size. +def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">; + +// Uses an attribute named `operand_segment_sizes` to specify how many actual +// operand each ODS-declared operand (variadic or not) corresponds to. +// This trait is used for ops that have multiple variadic operands but do +// not know statically their size relationship. The attribute must be a 1D +// vector that has the same number of elements as the number of ODS declared +// operands. That means even if some operands are non-variadic, the attribute +// still need to have an element for its size, which is always 1. +def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">; +// Similar to AttrSizedOperandSegments, but used for results. The attribute +// should be named as `result_segment_sizes`. +def AttrSizedResultSegments : NativeOpTrait<"AttrSizedResultSegments">; + +//===----------------------------------------------------------------------===// +// OpInterface definitions +//===----------------------------------------------------------------------===// + +// Marker used to identify the argument list for an op or interface method. +def ins; + +// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in +// C++. The purpose to wrap around C++ symbol string with this class is to make +// interfaces specified for ops in TableGen less alien and more integrated. 
+class OpInterfaceTrait : NativeOpTrait<""> { + let trait = name # "::Trait"; +} + +// This class represents a single, optionally static, interface method. +// Note: non-static interface methods have an implicit 'op' parameter +// corresponding to an instance of the derived operation. +class InterfaceMethod { + // A human-readable description of what this method does. + string description = desc; + + // The name of the interface method. + string name = methodName; + + // The c++ type-name of the return type. + string returnType = retTy; + + // A dag of string that correspond to the arguments of the method. + dag arguments = args; + + // An optional body to the method. + code body = methodBody; + + // An optional default implementation of the method. + code defaultBody = defaultImplementation; +} + +// This class represents a single static interface method. +class StaticInterfaceMethod + : InterfaceMethod; + +// OpInterface represents an interface regarding an op. +class OpInterface : OpInterfaceTrait { + // A human-readable description of what this interface does. + string description = ""; + + // The name given to the c++ interface class. + string cppClassName = name; + + // The list of methods defined by this interface. + list methods = []; +} + +// Whether to declare the op interface methods in the op's header. This class +// simply wraps an OpInterface but is used to indicate that the method +// declarations should be generated. +class DeclareOpInterfaceMethods : + OpInterface { + let description = interface.description; + let cppClassName = interface.cppClassName; + let methods = interface.methods; +} + +//===----------------------------------------------------------------------===// +// Op definitions +//===----------------------------------------------------------------------===// + +// Marker used to identify the result list for an op. +def outs; + +// Marker used to identify the region list for an op. +def region; + +// Class for defining a custom builder. 
+// +// TableGen generates several generic builders for each op by default (see +// comment in the `Op` class). If the default generated ones cannot cover +// some use case, custom builders can be defined using instances of this class. +// +// The signature of the builder is always +// +// ```c++ +// static void build(Builder *builder, OperationState &state, +// ...) { +// ... +// } +// ``` +// +// To define a custom builder, the parameter list (*including* the `Builder +// *builder, OperationState &state` part) and body should be passed in +// as separate template arguments to this class. This is because we generate +// op declaration and definition into separate files. If an empty string is +// passed in for `body`, then *only* the builder declaration will be +// generated; this provides a way to define complicated builders entirely +// in C++. +class OpBuilder { + string params = p; + code body = b; +} + +// Base class for all ops. +class Op props = []> { + // The dialect of the op. + Dialect opDialect = dialect; + + // The mnemonic of the op. + string opName = mnemonic; + + // One-line human-readable description of what the op does. + string summary = ""; + + // Additional, longer human-readable description of what the op does. + string description = ""; + + // Dag containing the arguments of the op. Default to 0 arguments. + dag arguments = (ins); + + // The list of results of the op. Default to 0 results. + dag results = (outs); + + // The list of regions of the op. Default to 0 regions. + dag regions = (region); + + // Attribute getters can be added to the op by adding an Attr member + // with the name and type of the attribute. E.g., adding int attribute + // with name "value" and type "i32": + // I32Attr value; + + // Define the hooks used for building, parsing, printing, verification. + + // Custom builder. 
+ // In addition to the custom builder provided here, and unless + // skipDefaultBuilders is set, two default builders are generated, with the + // following signatures: + // + // ```c++ + // static void build(Builder *, OperationState &tblgen_state, + // Type , Type , ..., + // Value , Value , ..., + // Attribute , Attribute , ...); + // ``` + // * where the attributes follow the same declaration order as in the op. + // + // ```c++ + // static void build(Builder *, OperationState &tblgen_state, + // ArrayRef resultTypes, + // ArrayRef operands, + // ArrayRef attributes); + // ``` + list builders = ?; + + // Avoid generating default build functions. Custom builders must be + // provided. + bit skipDefaultBuilders = 0; + + // Custom parser. + code parser = ?; + + // Custom printer. + code printer = ?; + + // Custom verifier. + code verifier = ?; + + // Whether this op has associated canonicalization patterns. + // TODO(b/120163349): figure out a better way to write canonicalization + // patterns in TableGen rules directly instead of using this marker + // and C++ implementations. + bit hasCanonicalizer = 0; + + // Whether this op has a folder. + bit hasFolder = 0; + + // Op traits. + // Note: The list of traits will be uniqued by ODS. + list traits = props; + + // Additional code that will be added to the public part of the generated + // C++ code of the op declaration. + code extraClassDeclaration = ?; +} + +// The arguments of an op. +class Arguments { + dag arguments = args; +} + +// The results of an op. 
+class Results { + dag results = rets; +} + +//===----------------------------------------------------------------------===// +// Common value constraints +//===----------------------------------------------------------------------===// + +def HasNoUseOf: Constraint< + CPred<"$_self->use_begin() == $_self->use_end()">, "has no use">; + +//===----------------------------------------------------------------------===// +// Common op type constraints +//===----------------------------------------------------------------------===// + +// These traits are for verifying properties of an op that require knowledge of +// multiple arguments or results. For verifying properties of a single argument +// or result, prefer operand type constraints. + +// These traits often require including "mlir/IR/TypeUtilities.h". + +// TODO(b/135033717): Improve the autogenerated error messages. + +class Rank : + StrFunc<"$" # name # ".getType().cast().getRank()">; + +class Shape : + StrFunc<"$" # name # ".getType().cast().getShape()">; + +class ElementCount : + StrFunc<"$" # name # ".getType().cast().getNumElements()">; + +class ElementType : StrFunc<"getElementTypeOrSelf($" # name # ")">; + +class AllMatchPred values> : + CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin.result #"}))">; + +class AllMatch values, string description> : + PredOpTrait>; + +// TODO(b/135032064): Only works for non-variadic. 
+class AllMatchSameOperatorPred names, string operator> : + AllMatchPred; + +class AllMatchSameOperatorTrait names, string operator, + string description> : + PredOpTrait< + "all of {" # StrJoin.result # "} have same " # description, + AllMatchSameOperatorPred>; + +class AllElementCountsMatch names> : + AllMatchSameOperatorTrait.result, + "element count">; + +class AllElementTypesMatch names> : + AllMatchSameOperatorTrait.result, + "element type">; + +class AllRanksMatch names> : + AllMatchSameOperatorTrait.result, "rank">; + +class AllShapesMatch names> : + AllMatchSameOperatorTrait.result, "shape">; + +class AllTypesMatch names> : + AllMatchSameOperatorTrait; + +// Type Constraint operand `idx`'s Element type is `type`. +class TCopVTEtIs : And<[ + CPred<"$_op.getNumOperands() > " # idx>, + SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()", + IsShapedTypePred>, + SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # idx # "))", + type.predicate>]>; + +// Predicate to verify that a named argument or result's element type matches a +// given type. +class TypeIsPred : + SubstLeaves<"$_self", "$" # name # ".getType()", type.predicate>; +class TypeIs : PredOpTrait< + "'" # name # "' is " # type.description, TypeIsPred>; + +// Predicate to verify that a named argument or result's element type matches a +// given type. +class ElementTypeIsPred : And<[ + SubstLeaves<"$_self", "$" # name # ".getType()", IsShapedTypePred>, + SubstLeaves<"$_self", "getElementTypeOrSelf($" # name # ")", + type.predicate>]>; +class ElementTypeIs : PredOpTrait< + "'" # name # "' is " # type.description, ElementTypeIsPred>; + +// Predicate to verify that the i'th operand and the j'th operand have the same +// elemental type. +// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element +// type. 
+class TCopVTEtIsSameAs : And<[ + CPred<"$_op.getNumOperands() > std::max(" # i # "u," # j # "u)">, + SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()", + IsShapedTypePred>, + SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", + IsShapedTypePred>, + CPred<"mlir::getElementTypeOrSelf($_op.getOperand(" # i # ")) == " + "mlir::getElementTypeOrSelf($_op.getOperand(" # j # "))">]>; + +// Predicate to verify that the i'th result and the j'th operand exist and has +// shaped types. +class TCOpResIsShapedTypePred : And<[ + CPred<"$_op.getNumResults() > " # i>, + CPred<"$_op.getNumOperands() > " # j>, + SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()", + IsShapedTypePred>, + SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", + IsShapedTypePred>]>; + +// Predicate to verify that the i'th result and the j'th operand have the same +// type. +class TCresIsSameAsOpBase : + CPred<"$_op.getResult(" # i # ")->getType() == " + "$_op.getOperand(" # j # ")->getType()">; + +// Basic Predicate to verify that the i'th result and the j'th operand have the +// same elemental type. +class TCresVTEtIsSameAsOpBase : + CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")) == " + "getElementTypeOrSelf($_op.getOperand(" # j # "))">; + +// Predicate to verify that the i'th result and the j'th operand have the same +// elemental type. +// Type Constraint result`i`'s Element type is Same As Operand `j`'s Element +// type. +class TCresVTEtIsSameAsOp : And<[ + TCOpResIsShapedTypePred, + TCresVTEtIsSameAsOpBase]>; + +// Predicate to verify that the opId'th operand can be broadcasted to the type +// of the resId'th result. +class TCOpIsBroadcastableToRes : And<[ + TCOpResIsShapedTypePred, + CPred<"OpTrait::util::getBroadcastedType(" + "$_op.getOperand(" # opId # ")->getType(), " + "$_op.getResult(" # resId # ")->getType())">]>; + +// Predicate to verify that all the operands at the given `indices` +// have the same element type. 
+// Type Constraint operands' Element type are all Same At the given `indices`. +// We query the operands' types into a list and check they are all the same. +// Precondition: +// 1) all operands involved are of shaped type and +// 2) the indices are not out of range. +class TCopVTEtAreSameAt indices> : CPred< + "llvm::is_splat(mlir::functional::map(" + "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }, " + "llvm::ArrayRef({" # StrJoinInt.result # "})))">; + +//===----------------------------------------------------------------------===// +// Pattern definitions +//===----------------------------------------------------------------------===// + +// Marker used to identify the delta value added to the default benefit value. +def addBenefit; + +// Base class for op+ -> op+ rewrite rules. These allow declaratively +// specifying rewrite rules. +// +// A rewrite rule contains two components: a source pattern and one or more +// result patterns. Each pattern is specified as a (recursive) DAG node (tree) +// in the form of `(node arg0, arg1, ...)`. +// +// The `node` are normally MLIR ops, but it can also be one of the directives +// listed later in this section. +// +// ## Symbol binding +// +// In the source pattern, `argN` can be used to specify matchers (e.g., using +// type/attribute type constraints, etc.) and bound to a name for later use. +// We can also bound names to op instances to reference them later in +// multi-entity constraints. +// +// In the result pattern, `argN` can be used to refer to a previously bound +// name, with potential transformations (e.g., using tAttr, etc.). `argN` can +// itself be nested DAG node. We can also bound names to ops to reference +// them later in other result patterns. 
+// +// For example, +// +// ``` +// def : Pattern<(OneResultOp1:$op1 $arg0, $arg1), +// [(OneResultOp2:$op2 $arg0, $arg1), +// (OneResultOp3 $op2 (OneResultOp4))], +// [(HasStaticShapePred $op1)]>; +// ``` +// +// `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to +// build `OneResultOp2`. `$op1` is bound to `OneResultOp1` and used to +// check whether the result's shape is static. `$op2` is bound to +// `OneResultOp2` and used to build `OneResultOp3`. +// +// ## Multi-result op +// +// To create multi-result ops in result pattern, you can use a syntax similar +// to uni-result op, and it will act as a value pack for all results: +// +// ``` +// def : Pattern<(ThreeResultOp ...), +// [(TwoResultOp ...), (OneResultOp ...)]>; +// ``` +// +// Then `TwoResultOp` will replace the first two values of `ThreeResultOp`. +// +// You can also use `$__N` to explicitly access the N-th result. +// ``` +// def : Pattern<(FiveResultOp ...), +// [(TwoResultOp1:$res1__1 ...), (replaceWithValue $res1__0), +// (TwoResultOp2:$res2 ...), (replaceWithValue $res2__1)]>; +// ``` +// +// Then the values generated by `FiveResultOp` will be replaced by +// +// * `FiveResultOp`#0: `TwoResultOp1`#1 +// * `FiveResultOp`#1: `TwoResultOp1`#0 +// * `FiveResultOp`#2: `TwoResultOp2`#0 +// * `FiveResultOp`#3: `TwoResultOp2`#1 +// * `FiveResultOp`#4: `TwoResultOp2`#1 +class Pattern results, list preds = [], + dag benefitAdded = (addBenefit 0)> { + dag sourcePattern = source; + // Result patterns. Each result pattern is expected to replace one result + // of the root op in the source pattern. In the case of more result patterns + // than needed to replace the source op, only the last N results generated + // by the last N result pattern is used to replace a N-result source op. + // So that the beginning result patterns can be used to generate additional + // ops to aid building the results used for replacement. + list resultPatterns = results; + // Multi-entity constraints. 
Each constraint here involves multiple entities + // matched in source pattern and places further constraints on them as a + // whole. + list constraints = preds; + // The delta value added to the default benefit value. The default value is + // the number of ops in the source pattern. The rule with the highest final + // benefit value will be applied first if there are multiple rules matches. + // This delta value can be either positive or negative. + dag benefitDelta = benefitAdded; +} + +// Form of a pattern which produces a single result. +class Pat preds = [], + dag benefitAdded = (addBenefit 0)> : + Pattern; + +// Native code call wrapper. This allows invoking an arbitrary C++ expression +// to create an op operand/attribute or replace an op result. +// +// ## Placeholders +// +// If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`, +// the wrapped expression can take special placeholders listed below: +// +// * `$_builder` will be replaced by the current `mlir::PatternRewriter`. +// * `$_self` will be replaced with the entity this transformer is attached to. +// E.g., with the definition `def transform : NativeCodeCall<"$_self...">`, +// `$_self` in `transform:$attr` will be replaced by the value for `$attr`. +// +// If used as a DAG node, i.e., `(NativeCodeCall<"..."> , ..., )`, +// then positional placeholders are also supported; placeholder `$N` in the +// wrapped C++ expression will be replaced by ``. + +class NativeCodeCall { + string expression = expr; +} + +//===----------------------------------------------------------------------===// +// Common directives +//===----------------------------------------------------------------------===// + +// Directive used in result pattern to indicate that no new op are generated, +// so to replace the matched DAG with an existing SSA value. 
+def replaceWithValue;
+
+#endif // OP_BASE
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
new file mode 100644
index 0000000000000000000000000000000000000000..1abf82f37ee4623da5bf9ca4363a7f73dca601ba
--- /dev/null
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -0,0 +1,1225 @@
+//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helper classes for implementing the "Op" types. This
+// includes the Op type, which is the base class for Op class definitions,
+// as well as number of traits in the OpTrait namespace that provide a
+// declarative way to specify properties of Ops.
+//
+// The purpose of these types are to allow light-weight implementation of
+// concrete ops (like DimOp) with very little boilerplate.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OPDEFINITION_H
+#define MLIR_IR_OPDEFINITION_H
+
+#include "mlir/IR/Operation.h"
+#include <type_traits>
+
+namespace mlir {
+class Builder;
+
+namespace OpTrait {
+template <typename ConcreteType> class OneResult;
+}
+
+/// This class represents success/failure for operation parsing. It is
+/// essentially a simple wrapper class around LogicalResult that allows for
+/// explicit conversion to bool. This allows for the parser to chain together
+/// parse rules without the clutter of "failed/succeeded".
+class ParseResult : public LogicalResult {
+public:
+  ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
+
+  // Allow diagnostics emitted during parsing to be converted to failure.
+ ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {} + ParseResult(const Diagnostic &) : LogicalResult(failure()) {} + + /// Failure is true in a boolean context. + explicit operator bool() const { return failed(*this); } +}; +/// This class implements `Optional` functionality for ParseResult. We don't +/// directly use Optional here, because it provides an implicit conversion +/// to 'bool' which we want to avoid. This class is used to implement tri-state +/// 'parseOptional' functions that may have a failure mode when parsing that +/// shouldn't be attributed to "not present". +class OptionalParseResult { +public: + OptionalParseResult() = default; + OptionalParseResult(LogicalResult result) : impl(result) {} + OptionalParseResult(ParseResult result) : impl(result) {} + OptionalParseResult(const InFlightDiagnostic &) + : OptionalParseResult(failure()) {} + OptionalParseResult(llvm::NoneType) : impl(llvm::None) {} + + /// Returns true if we contain a valid ParseResult value. + bool hasValue() const { return impl.hasValue(); } + + /// Access the internal ParseResult value. + ParseResult getValue() const { return impl.getValue(); } + ParseResult operator*() const { return getValue(); } + +private: + Optional impl; +}; + +// These functions are out-of-line utilities, which avoids them being template +// instantiated/duplicated. +namespace impl { +/// Insert an operation, generated by `buildTerminatorOp`, at the end of the +/// region's only block if it does not have a terminator already. If the region +/// is empty, insert a new block first. `buildTerminatorOp` should return the +/// terminator operation to insert. +void ensureRegionTerminator(Region ®ion, Location loc, + function_ref buildTerminatorOp); +/// Templated version that fills the generates the provided operation type. 
+template +void ensureRegionTerminator(Region ®ion, Builder &builder, Location loc) { + ensureRegionTerminator(region, loc, [&] { + OperationState state(loc, OpTy::getOperationName()); + OpTy::build(&builder, state); + return Operation::create(state); + }); +} +} // namespace impl + +/// This is the concrete base class that holds the operation pointer and has +/// non-generic methods that only depend on State (to avoid having them +/// instantiated on template types that don't affect them. +/// +/// This also has the fallback implementations of customization hooks for when +/// they aren't customized. +class OpState { +public: + /// Ops are pointer-like, so we allow implicit conversion to bool. + operator bool() { return getOperation() != nullptr; } + + /// This implicitly converts to Operation*. + operator Operation *() const { return state; } + + /// Return the operation that this refers to. + Operation *getOperation() { return state; } + + /// Returns the closest surrounding operation that contains this operation + /// or nullptr if this is a top-level operation. + Operation *getParentOp() { return getOperation()->getParentOp(); } + + /// Return the closest surrounding parent operation that is of type 'OpTy'. + template OpTy getParentOfType() { + return getOperation()->getParentOfType(); + } + + /// Return the context this operation belongs to. + MLIRContext *getContext() { return getOperation()->getContext(); } + + /// Print the operation to the given stream. + void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) { + state->print(os, flags); + } + + /// Dump this operation. + void dump() { state->dump(); } + + /// The source location the operation was defined or derived from. + Location getLoc() { return state->getLoc(); } + void setLoc(Location loc) { state->setLoc(loc); } + + /// Return all of the attributes on this operation. + ArrayRef getAttrs() { return state->getAttrs(); } + + /// A utility iterator that filters out non-dialect attributes. 
+ using dialect_attr_iterator = Operation::dialect_attr_iterator; + using dialect_attr_range = Operation::dialect_attr_range; + + /// Return a range corresponding to the dialect attributes for this operation. + dialect_attr_range getDialectAttrs() { return state->getDialectAttrs(); } + dialect_attr_iterator dialect_attr_begin() { + return state->dialect_attr_begin(); + } + dialect_attr_iterator dialect_attr_end() { return state->dialect_attr_end(); } + + /// Return an attribute with the specified name. + Attribute getAttr(StringRef name) { return state->getAttr(name); } + + /// If the operation has an attribute of the specified type, return it. + template AttrClass getAttrOfType(StringRef name) { + return getAttr(name).dyn_cast_or_null(); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setAttr(Identifier name, Attribute value) { + state->setAttr(name, value); + } + void setAttr(StringRef name, Attribute value) { + setAttr(Identifier::get(name, getContext()), value); + } + + /// Set the attributes held by this operation. + void setAttrs(ArrayRef attributes) { + state->setAttrs(attributes); + } + void setAttrs(NamedAttributeList newAttrs) { state->setAttrs(newAttrs); } + + /// Set the dialect attributes for this operation, and preserve all dependent. + template void setDialectAttrs(DialectAttrs &&attrs) { + state->setDialectAttrs(std::move(attrs)); + } + + /// Remove the attribute with the specified name if it exists. The return + /// value indicates whether the attribute was present or not. + NamedAttributeList::RemoveResult removeAttr(Identifier name) { + return state->removeAttr(name); + } + NamedAttributeList::RemoveResult removeAttr(StringRef name) { + return state->removeAttr(Identifier::get(name, getContext())); + } + + /// Return true if there are no users of any results of this operation. 
+ bool use_empty() { return state->use_empty(); } + + /// Remove this operation from its parent block and delete it. + void erase() { state->erase(); } + + /// Emit an error with the op name prefixed, like "'dim' op " which is + /// convenient for verifiers. + InFlightDiagnostic emitOpError(const Twine &message = {}); + + /// Emit an error about fatal conditions with this operation, reporting up to + /// any diagnostic handlers that may be listening. + InFlightDiagnostic emitError(const Twine &message = {}); + + /// Emit a warning about this operation, reporting up to any diagnostic + /// handlers that may be listening. + InFlightDiagnostic emitWarning(const Twine &message = {}); + + /// Emit a remark about this operation, reporting up to any diagnostic + /// handlers that may be listening. + InFlightDiagnostic emitRemark(const Twine &message = {}); + + /// Walk the operation in postorder, calling the callback for each nested + /// operation(including this one). + /// See Operation::walk for more details. + template > + RetT walk(FnT &&callback) { + return state->walk(std::forward(callback)); + } + + // These are default implementations of customization hooks. +public: + /// This hook returns any canonicalization pattern rewrites that the operation + /// supports, for use by the canonicalization pass. + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) {} + +protected: + /// If the concrete type didn't implement a custom verifier hook, just fall + /// back to this one which accepts everything. + LogicalResult verify() { return success(); } + + /// Unless overridden, the custom assembly form of an op is always rejected. + /// Op implementations should implement this to return failure. + /// On success, they should fill in result with the fields to use. + static ParseResult parse(OpAsmParser &parser, OperationState &result); + + // The fallback for the printer is to print it the generic assembly form. 
+ void print(OpAsmPrinter &p); + + /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, + /// so we can cast it away here. + explicit OpState(Operation *state) : state(state) {} + +private: + Operation *state; +}; + +// Allow comparing operators. +inline bool operator==(OpState lhs, OpState rhs) { + return lhs.getOperation() == rhs.getOperation(); +} +inline bool operator!=(OpState lhs, OpState rhs) { + return lhs.getOperation() != rhs.getOperation(); +} + +/// This class represents a single result from folding an operation. +class OpFoldResult : public PointerUnion { + using PointerUnion::PointerUnion; +}; + +/// This template defines the foldHook as used by AbstractOperation. +/// +/// The default implementation uses a general fold method that can be defined on +/// custom ops which can return multiple results. +template +class FoldingHook { +public: + /// This is an implementation detail of the constant folder hook for + /// AbstractOperation. + static LogicalResult foldHook(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return cast(op).fold(operands, results); + } + + /// This hook implements a generalized folder for this operation. Operations + /// can implement this to provide simplifications rules that are applied by + /// the Builder::createOrFold API and the canonicalization pass. + /// + /// This is an intentionally limited interface - implementations of this hook + /// can only perform the following changes to the operation: + /// + /// 1. They can leave the operation alone and without changing the IR, and + /// return failure. + /// 2. They can mutate the operation in place, without changing anything else + /// in the IR. In this case, return success. + /// 3. They can return a list of existing values that can be used instead of + /// the operation. In this case, fill in the results list and return + /// success. The caller will remove the operation and use those results + /// instead. 
+ /// + /// This allows expression of some simple in-place canonicalizations (e.g. + /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as + /// generalized constant folding. + /// + /// If not overridden, this fallback implementation always fails to fold. + /// + LogicalResult fold(ArrayRef operands, + SmallVectorImpl &results) { + return failure(); + } +}; + +/// This template specialization defines the foldHook as used by +/// AbstractOperation for single-result operations. This gives the hook a nicer +/// signature that is easier to implement. +template +class FoldingHook::type> { +public: + /// If the operation returns a single value, then the Op can be implicitly + /// converted to an Value. This yields the value of the only result. + operator Value() { + return static_cast(this)->getOperation()->getResult(0); + } + + /// This is an implementation detail of the constant folder hook for + /// AbstractOperation. + static LogicalResult foldHook(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + auto result = cast(op).fold(operands); + if (!result) + return failure(); + + // Check if the operation was folded in place. In this case, the operation + // returns itself. + if (result.template dyn_cast() != op->getResult(0)) + results.push_back(result); + return success(); + } + + /// This hook implements a generalized folder for this operation. Operations + /// can implement this to provide simplifications rules that are applied by + /// the Builder::createOrFold API and the canonicalization pass. + /// + /// This is an intentionally limited interface - implementations of this hook + /// can only perform the following changes to the operation: + /// + /// 1. They can leave the operation alone and without changing the IR, and + /// return nullptr. + /// 2. They can mutate the operation in place, without changing anything else + /// in the IR. In this case, return the operation itself. + /// 3. 
They can return an existing SSA value that can be used instead of + /// the operation. In this case, return that value. The caller will + /// remove the operation and use that result instead. + /// + /// This allows expression of some simple in-place canonicalizations (e.g. + /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as + /// generalized constant folding. + /// + /// If not overridden, this fallback implementation always fails to fold. + /// + OpFoldResult fold(ArrayRef operands) { return {}; } +}; + +//===----------------------------------------------------------------------===// +// Operation Trait Types +//===----------------------------------------------------------------------===// + +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +LogicalResult verifyZeroOperands(Operation *op); +LogicalResult verifyOneOperand(Operation *op); +LogicalResult verifyNOperands(Operation *op, unsigned numOperands); +LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); +LogicalResult verifyOperandsAreFloatLike(Operation *op); +LogicalResult verifyOperandsAreIntegerLike(Operation *op); +LogicalResult verifySameTypeOperands(Operation *op); +LogicalResult verifyZeroResult(Operation *op); +LogicalResult verifyOneResult(Operation *op); +LogicalResult verifyNResults(Operation *op, unsigned numOperands); +LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands); +LogicalResult verifySameOperandsShape(Operation *op); +LogicalResult verifySameOperandsAndResultShape(Operation *op); +LogicalResult verifySameOperandsElementType(Operation *op); +LogicalResult verifySameOperandsAndResultElementType(Operation *op); +LogicalResult verifySameOperandsAndResultType(Operation *op); +LogicalResult verifyResultsAreBoolLike(Operation *op); +LogicalResult 
verifyResultsAreFloatLike(Operation *op); +LogicalResult verifyResultsAreIntegerLike(Operation *op); +LogicalResult verifyIsTerminator(Operation *op); +LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); +LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); +} // namespace impl + +/// Helper class for implementing traits. Clients are not expected to interact +/// with this directly, so its members are all protected. +template class TraitType> +class TraitBase { +protected: + /// Return the ultimate Operation being worked on. + Operation *getOperation() { + // We have to cast up to the trait type, then to the concrete type, then to + // the BaseState class in explicit hops because the concrete type will + // multiply derive from the (content free) TraitBase class, and we need to + // be able to disambiguate the path for the C++ compiler. + auto *trait = static_cast *>(this); + auto *concrete = static_cast(trait); + auto *base = static_cast(concrete); + return base->getOperation(); + } + + /// Provide default implementations of trait hooks. This allows traits to + /// provide exactly the overrides they care about. + static LogicalResult verifyTrait(Operation *op) { return success(); } + static AbstractOperation::OperationProperties getTraitProperties() { + return 0; + } +}; + +namespace detail { +/// Utility trait base that provides accessors for derived traits that have +/// multiple operands. +template class TraitType> +struct MultiOperandTraitBase : public TraitBase { + using operand_iterator = Operation::operand_iterator; + using operand_range = Operation::operand_range; + using operand_type_iterator = Operation::operand_type_iterator; + using operand_type_range = Operation::operand_type_range; + + /// Return the number of operands. + unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } + + /// Return the operand at index 'i'. 
+ Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); } + + /// Set the operand at index 'i' to 'value'. + void setOperand(unsigned i, Value value) { + this->getOperation()->setOperand(i, value); + } + + /// Operand iterator access. + operand_iterator operand_begin() { + return this->getOperation()->operand_begin(); + } + operand_iterator operand_end() { return this->getOperation()->operand_end(); } + operand_range getOperands() { return this->getOperation()->getOperands(); } + + /// Operand type access. + operand_type_iterator operand_type_begin() { + return this->getOperation()->operand_type_begin(); + } + operand_type_iterator operand_type_end() { + return this->getOperation()->operand_type_end(); + } + operand_type_range getOperandTypes() { + return this->getOperation()->getOperandTypes(); + } +}; +} // end namespace detail + +/// This class provides the API for ops that are known to have no +/// SSA operand. +template +class ZeroOperands : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyZeroOperands(op); + } + +private: + // Disable these. + void getOperand() {} + void setOperand() {} +}; + +/// This class provides the API for ops that are known to have exactly one +/// SSA operand. +template +class OneOperand : public TraitBase { +public: + Value getOperand() { return this->getOperation()->getOperand(0); } + + void setOperand(Value value) { this->getOperation()->setOperand(0, value); } + + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOneOperand(op); + } +}; + +/// This class provides the API for ops that are known to have a specified +/// number of operands. 
This is used as a trait like this: +/// +/// class FooOp : public Op::Impl> { +/// +template class NOperands { +public: + static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2"); + + template + class Impl + : public detail::MultiOperandTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyNOperands(op, N); + } + }; +}; + +/// This class provides the API for ops that are known to have a at least a +/// specified number of operands. This is used as a trait like this: +/// +/// class FooOp : public Op::Impl> { +/// +template class AtLeastNOperands { +public: + template + class Impl : public detail::MultiOperandTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyAtLeastNOperands(op, N); + } + }; +}; + +/// This class provides the API for ops which have an unknown number of +/// SSA operands. +template +class VariadicOperands + : public detail::MultiOperandTraitBase {}; + +/// This class provides return value APIs for ops that are known to have +/// zero results. +template +class ZeroResult : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyZeroResult(op); + } +}; + +namespace detail { +/// Utility trait base that provides accessors for derived traits that have +/// multiple results. +template class TraitType> +struct MultiResultTraitBase : public TraitBase { + using result_iterator = Operation::result_iterator; + using result_range = Operation::result_range; + using result_type_iterator = Operation::result_type_iterator; + using result_type_range = Operation::result_type_range; + + /// Return the number of results. + unsigned getNumResults() { return this->getOperation()->getNumResults(); } + + /// Return the result at index 'i'. + Value getResult(unsigned i) { return this->getOperation()->getResult(i); } + + /// Replace all uses of results of this operation with the provided 'values'. 
+ /// 'values' may correspond to an existing operation, or a range of 'Value'. + template void replaceAllUsesWith(ValuesT &&values) { + this->getOperation()->replaceAllUsesWith(std::forward(values)); + } + + /// Return the type of the `i`-th result. + Type getType(unsigned i) { return getResult(i)->getType(); } + + /// Result iterator access. + result_iterator result_begin() { + return this->getOperation()->result_begin(); + } + result_iterator result_end() { return this->getOperation()->result_end(); } + result_range getResults() { return this->getOperation()->getResults(); } + + /// Result type access. + result_type_iterator result_type_begin() { + return this->getOperation()->result_type_begin(); + } + result_type_iterator result_type_end() { + return this->getOperation()->result_type_end(); + } + result_type_range getResultTypes() { + return this->getOperation()->getResultTypes(); + } +}; +} // end namespace detail + +/// This class provides return value APIs for ops that are known to have a +/// single result. +template +class OneResult : public TraitBase { +public: + Value getResult() { return this->getOperation()->getResult(0); } + Type getType() { return getResult()->getType(); } + + /// Replace all uses of 'this' value with the new value, updating anything in + /// the IR that uses 'this' to use the other value instead. When this returns + /// there are zero uses of 'this'. + void replaceAllUsesWith(Value newValue) { + getResult()->replaceAllUsesWith(newValue); + } + + /// Replace all uses of 'this' value with the result of 'op'. + void replaceAllUsesWith(Operation *op) { + this->getOperation()->replaceAllUsesWith(op); + } + + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOneResult(op); + } +}; + +/// This class provides the API for ops that are known to have a specified +/// number of results. 
This is used as a trait like this: +/// +/// class FooOp : public Op::Impl> { +/// +template class NResults { +public: + static_assert(N > 1, "use ZeroResult/OneResult for N < 2"); + + template + class Impl + : public detail::MultiResultTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyNResults(op, N); + } + }; +}; + +/// This class provides the API for ops that are known to have at least a +/// specified number of results. This is used as a trait like this: +/// +/// class FooOp : public Op::Impl> { +/// +template class AtLeastNResults { +public: + template + class Impl : public detail::MultiResultTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyAtLeastNResults(op, N); + } + }; +}; + +/// This class provides the API for ops which have an unknown number of +/// results. +template +class VariadicResults + : public detail::MultiResultTraitBase {}; + +/// This class provides verification for ops that are known to have the same +/// operand shape: all operands are scalars, vectors/tensors of the same +/// shape. +template +class SameOperandsShape : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsShape(op); + } +}; + +/// This class provides verification for ops that are known to have the same +/// operand and result shape: both are scalars, vectors/tensors of the same +/// shape. +template +class SameOperandsAndResultShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultShape(op); + } +}; + +/// This class provides verification for ops that are known to have the same +/// operand element type (or the type itself if it is scalar). 
+/// +template +class SameOperandsElementType + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsElementType(op); + } +}; + +/// This class provides verification for ops that are known to have the same +/// operand and result element type (or the type itself if it is scalar). +/// +template +class SameOperandsAndResultElementType + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultElementType(op); + } +}; + +/// This class provides verification for ops that are known to have the same +/// operand and result type. +/// +/// Note: this trait subsumes the SameOperandsAndResultShape and +/// SameOperandsAndResultElementType traits. +template +class SameOperandsAndResultType + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultType(op); + } +}; + +/// This class verifies that any results of the specified op have a boolean +/// type, a vector thereof, or a tensor thereof. +template +class ResultsAreBoolLike : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyResultsAreBoolLike(op); + } +}; + +/// This class verifies that any results of the specified op have a floating +/// point type, a vector thereof, or a tensor thereof. +template +class ResultsAreFloatLike + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyResultsAreFloatLike(op); + } +}; + +/// This class verifies that any results of the specified op have an integer or +/// index type, a vector thereof, or a tensor thereof. +template +class ResultsAreIntegerLike + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyResultsAreIntegerLike(op); + } +}; + +/// This class adds property that the operation is commutative. 
+template +class IsCommutative : public TraitBase { +public: + static AbstractOperation::OperationProperties getTraitProperties() { + return static_cast( + OperationProperty::Commutative); + } +}; + +/// This class adds property that the operation has no side effects. +template +class HasNoSideEffect : public TraitBase { +public: + static AbstractOperation::OperationProperties getTraitProperties() { + return static_cast( + OperationProperty::NoSideEffect); + } +}; + +/// This class verifies that all operands of the specified op have a float type, +/// a vector thereof, or a tensor thereof. +template +class OperandsAreFloatLike + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOperandsAreFloatLike(op); + } +}; + +/// This class verifies that all operands of the specified op have an integer or +/// index type, a vector thereof, or a tensor thereof. +template +class OperandsAreIntegerLike + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOperandsAreIntegerLike(op); + } +}; + +/// This class verifies that all operands of the specified op have the same +/// type. +template +class SameTypeOperands : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameTypeOperands(op); + } +}; + +/// This class provides the API for ops that are known to be terminators. 
+template +class IsTerminator : public TraitBase { +public: + static AbstractOperation::OperationProperties getTraitProperties() { + return static_cast( + OperationProperty::Terminator); + } + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyIsTerminator(op); + } + + unsigned getNumSuccessors() { + return this->getOperation()->getNumSuccessors(); + } + unsigned getNumSuccessorOperands(unsigned index) { + return this->getOperation()->getNumSuccessorOperands(index); + } + + Block *getSuccessor(unsigned index) { + return this->getOperation()->getSuccessor(index); + } + + void setSuccessor(Block *block, unsigned index) { + return this->getOperation()->setSuccessor(block, index); + } + + void addSuccessorOperand(unsigned index, Value value) { + return this->getOperation()->addSuccessorOperand(index, value); + } + void addSuccessorOperands(unsigned index, ArrayRef values) { + return this->getOperation()->addSuccessorOperand(index, values); + } +}; + +/// This class provides the API for ops that are known to be isolated from +/// above. +template +class IsIsolatedFromAbove + : public TraitBase { +public: + static AbstractOperation::OperationProperties getTraitProperties() { + return static_cast( + OperationProperty::IsolatedFromAbove); + } + static LogicalResult verifyTrait(Operation *op) { + for (auto ®ion : op->getRegions()) + if (!region.isIsolatedFromAbove(op->getLoc())) + return failure(); + return success(); + } +}; + +/// This class provides APIs and verifiers for ops with regions having a single +/// block that must terminate with `TerminatorOpType`. +template struct SingleBlockImplicitTerminator { + template + class Impl : public TraitBase { + public: + static LogicalResult verifyTrait(Operation *op) { + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + Region ®ion = op->getRegion(i); + + // Empty regions are fine. + if (region.empty()) + continue; + + // Non-empty regions must contain a single basic block. 
+ if (std::next(region.begin()) != region.end()) + return op->emitOpError("expects region #") + << i << " to have 0 or 1 blocks"; + + Block &block = region.front(); + if (block.empty()) + return op->emitOpError() << "expects a non-empty block"; + Operation &terminator = block.back(); + if (isa(terminator)) + continue; + + return op->emitOpError("expects regions to end with '" + + TerminatorOpType::getOperationName() + + "', found '" + + terminator.getName().getStringRef() + "'") + .attachNote() + << "in custom textual format, the absence of terminator implies " + "'" + << TerminatorOpType::getOperationName() << '\''; + } + + return success(); + } + + /// Ensure that the given region has the terminator required by this trait. + static void ensureTerminator(Region ®ion, Builder &builder, + Location loc) { + ::mlir::impl::template ensureRegionTerminator( + region, builder, loc); + } + }; +}; + +/// This class provides a verifier for ops that are expecting a specific parent. +template struct HasParent { + template + class Impl : public TraitBase { + public: + static LogicalResult verifyTrait(Operation *op) { + if (isa(op->getParentOp())) + return success(); + return op->emitOpError() << "expects parent op '" + << ParentOpType::getOperationName() << "'"; + } + }; +}; + +/// A trait for operations that have an attribute specifying operand segments. +/// +/// Certain operations can have multiple variadic operands and their size +/// relationship is not always known statically. For such cases, we need +/// a per-op-instance specification to divide the operands into logical groups +/// or segments. This can be modeled by attributes. The attribute will be named +/// as `operand_segment_sizes`. +/// +/// This trait verifies the attribute for specifying operand segments has +/// the correct type (1D vector) and values (non-negative), etc. 
+template +class AttrSizedOperandSegments + : public TraitBase { +public: + static StringRef getOperandSegmentSizeAttr() { + return "operand_segment_sizes"; + } + + static LogicalResult verifyTrait(Operation *op) { + return ::mlir::OpTrait::impl::verifyOperandSizeAttr( + op, getOperandSegmentSizeAttr()); + } +}; + +/// Similar to AttrSizedOperandSegments but used for results. +template +class AttrSizedResultSegments + : public TraitBase { +public: + static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; } + + static LogicalResult verifyTrait(Operation *op) { + return ::mlir::OpTrait::impl::verifyResultSizeAttr( + op, getResultSegmentSizeAttr()); + } +}; + +} // end namespace OpTrait + +//===----------------------------------------------------------------------===// +// Operation Definition classes +//===----------------------------------------------------------------------===// + +/// This provides public APIs that all operations should have. The template +/// argument 'ConcreteType' should be the concrete type by CRTP and the others +/// are base classes by the policy pattern. +template class... Traits> +class Op : public OpState, + public Traits..., + public FoldingHook, + Traits...>::value> { +public: + /// Return if this operation contains the provided trait. + template