Skip to content
Unverified Commit 7c850867 authored by Maksim Levental's avatar Maksim Levental Committed by GitHub
Browse files

[mlir][python] value casting (#69644)

This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a
proxy class that overloads dunders such as `__add__`, `__sub__`, and
`__mul__` for fun and great profit.

This is thematically similar to
https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
and
https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f.
The example in the test demonstrates the value of the feature (no pun
intended):

```python
    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

    a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
    b = a + a
    # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
    print(b)

    a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
    b = a - a
    # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
    print(b)

    a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
    b = a * a
    # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
    print(b)
```

**EDIT**: this now goes through the bindings and thus supports automatic
casting of `OpResult` (including as an element of `OpResultList`),
`BlockArgument` (including as an element of `BlockArgumentList`), as
well as `Value`.
parent 867ece18
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment