From 9f4414269c98753df06c5e04bc1300e8e0e1eb45 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 17 Apr 2024 21:00:21 +0200 Subject: [PATCH] refactor[next]: Use type specification for itir.Literal (#1529) Small PR in preparation of the new ITIR type system. Currently the type of a `itir.Literal` is stored as a string which blocks introducing a `type: ts.TypeSpecification` attribute in all `itir.Node`s. In order to keep the PR for the new type inference easy to review this has been factored out. ```python class Literal(Expr): value: str type: str @datamodels.validator("type") def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): if value not in TYPEBUILTINS: raise ValueError(f"'{value}' is not a valid builtin type.") ``` is changed to ```python class Literal(Expr): value: str type: ts.ScalarType ``` --- src/gt4py/next/ffront/past_to_itir.py | 4 +- src/gt4py/next/iterator/ir.py | 8 +--- src/gt4py/next/iterator/ir_utils/ir_makers.py | 25 +++++++----- src/gt4py/next/iterator/pretty_parser.py | 4 +- .../iterator/transforms/collapse_tuple.py | 3 +- .../next/iterator/transforms/inline_lifts.py | 2 +- src/gt4py/next/iterator/type_inference.py | 8 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 7 ++-- .../runners/dace_iterator/itir_to_sdfg.py | 4 +- src/gt4py/next/type_system/type_info.py | 21 +++++++++- .../ffront_tests/test_past_to_itir.py | 19 +++++++-- .../iterator_tests/test_pretty_parser.py | 11 ++--- .../iterator_tests/test_pretty_printer.py | 15 +++---- .../iterator_tests/test_type_inference.py | 40 +++++++++---------- .../test_collapse_list_get.py | 11 +++-- .../transforms_tests/test_global_tmps.py | 16 ++++---- .../transforms_tests/test_inline_into_scan.py | 5 ++- .../test_scan_eta_reduction.py | 5 ++- .../test_simple_inline_heuristic.py | 5 ++- .../transforms_tests/test_unroll_reduce.py | 9 +++-- .../gtfn_tests/test_gtfn_module.py | 7 ++-- 21 files changed, 134 insertions(+), 95 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a7e9751c4e..fb5c1a6882 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -328,7 +328,7 @@ def _construct_itir_domain_arg( else: lower = self._visit_slice_bound( slices[dim_i].lower if slices else None, - itir.Literal(value="0", type=itir.INTEGER_INDEX_BUILTIN), + im.literal("0", itir.INTEGER_INDEX_BUILTIN), dim_size, ) upper = self._visit_slice_bound( @@ -458,7 +458,7 @@ def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal: f"Scalars of kind '{node.type.kind}' not supported currently." ) typename = node.type.kind.name.lower() - return itir.Literal(value=str(node.value), type=typename) + return im.literal(str(node.value), typename) raise NotImplementedError("Only scalar literals supported currently.") diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 5d1907b26f..538ac84cb8 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -20,6 +20,7 @@ from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable +from gt4py.next.type_system import type_specifications as ts # TODO(havogt): @@ -73,12 +74,7 @@ class Expr(Node): ... class Literal(Expr): value: str - type: str - - @datamodels.validator("type") - def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if value not in TYPEBUILTINS: - raise ValueError(f"'{value}' is not a valid builtin type.") + type: ts.ScalarType class NoneLiteral(Expr): diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 8e505be0ec..7fe05594ad 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -64,7 +64,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti SymRef(id=SymbolRef('a')) >>> ensure_expr(3) - Literal(value='3', type='int32') + Literal(value='3', type=ScalarType(kind=, shape=None)) >>> ensure_expr(itir.OffsetLiteral(value="i")) OffsetLiteral(value='i') @@ -94,6 +94,13 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of return str_or_offset +def ensure_type(type_: str | ts.TypeSpec | None) -> ts.TypeSpec | None: + if isinstance(type_, str): + return ts.ScalarType(kind=getattr(ts.ScalarKind, type_.upper())) + assert isinstance(type_, ts.TypeSpec) or type_ is None + return type_ + + class lambda_: """ Create a lambda from params and an expression. @@ -118,7 +125,7 @@ class call: Examples -------- >>> call("plus")(1, 1) - FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int32'), Literal(value='1', type='int32')]) + FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), Literal(value='1', type=ScalarType(kind=, shape=None))]) """ def __init__(self, expr): @@ -291,8 +298,8 @@ def shift(offset, value=None): return call(call("shift")(*args)) -def literal(value: str, typename: str): - return itir.Literal(value=value, type=typename) +def literal(value: str, typename: str) -> itir.Literal: + return itir.Literal(value=value, type=ensure_type(typename)) def literal_from_value(val: core_defs.Scalar) -> itir.Literal: @@ -300,13 +307,13 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Make a literal node from a value. >>> literal_from_value(1.0) - Literal(value='1.0', type='float64') + Literal(value='1.0', type=ScalarType(kind=, shape=None)) >>> literal_from_value(1) - Literal(value='1', type='int32') + Literal(value='1', type=ScalarType(kind=, shape=None)) >>> literal_from_value(2147483648) - Literal(value='2147483648', type='int64') + Literal(value='2147483648', type=ScalarType(kind=, shape=None)) >>> literal_from_value(True) - Literal(value='True', type='bool') + Literal(value='True', type=ScalarType(kind=, shape=None)) """ if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673 raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.") @@ -321,7 +328,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: typename = type_spec.kind.name.lower() assert typename in itir.TYPEBUILTINS - return itir.Literal(value=str(val), type=typename) + return literal(str(val), typename) def neighbors(offset, it): diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index f6d532ee30..3b7a2522a1 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -102,14 +102,14 @@ def SYM(self, value: lark_lexer.Token) -> ir.Sym: def SYM_REF(self, value: lark_lexer.Token) -> Union[ir.SymRef, ir.Literal]: if value.value in ("True", "False"): - return ir.Literal(value=value.value, type="bool") + return im.literal(value.value, "bool") return ir.SymRef(id=value.value) def INT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal: return im.literal_from_value(int(value.value)) def FLOAT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal: - return ir.Literal(value=value.value, type="float64") + return im.literal(value.value, "float64") def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: v: Union[int, str] = value.value[:-1] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 2bc33e85e1..4b8182a781 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -26,6 +26,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda +from gt4py.next.type_system import type_info class UnknownLength: @@ -232,7 +233,7 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ and isinstance(node.args[0], ir.Literal) ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` - assert node.args[0].type in ir.INTEGER_BUILTINS + assert type_info.is_integer(node.args[0].type) make_tuple_call = node.args[1] idx = int(node.args[0].value) assert idx < len( diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index bf56186253..74ef37fa0c 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -201,7 +201,7 @@ def visit_FunCall( assert len(node.args[0].fun.args) == 1 args = node.args[0].args if len(args) == 0: - return ir.Literal(value="True", type="bool") + return im.literal_from_value(True) res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]]) for arg in args[1:]: diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 1aae474c4c..89fed49551 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -23,6 +23,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify +from gt4py.next.type_system import type_info """Constraint-based inference for the iterator IR.""" @@ -643,7 +644,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: return TypeVar.fresh() def visit_Literal(self, node: ir.Literal, **kwargs) -> Val: - return Val(kind=Value(), dtype=Primitive(name=node.type)) + return Val(kind=Value(), dtype=Primitive(name=node.type.kind.name.lower())) def visit_AxisLiteral(self, node: ir.AxisLiteral, **kwargs) -> Val: return Val(kind=Value(), dtype=AXIS_DTYPE, size=Scalar()) @@ -672,10 +673,7 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: # Calls to `tuple_get` are handled as being part of the grammar, not as function calls. if len(node.args) != 2: raise TypeError("'tuple_get' requires exactly two arguments.") - if ( - not isinstance(node.args[0], ir.Literal) - or node.args[0].type != ir.INTEGER_INDEX_BUILTIN - ): + if not isinstance(node.args[0], ir.Literal) or not type_info.is_integer(node.args[0].type): raise TypeError( f"The first argument to 'tuple_get' must be a literal of type '{ir.INTEGER_INDEX_BUILTIN}'." ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index bbde720677..e5ba965e3b 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -45,6 +45,7 @@ UnstructuredDomain, ) from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef +from gt4py.next.type_system import type_info def pytype_to_cpptype(t: str) -> Optional[str]: @@ -183,7 +184,7 @@ def _collect_offset_definitions( def _literal_as_integral_constant(node: itir.Literal) -> IntegralConstant: - assert node.type in itir.INTEGER_BUILTINS + assert type_info.is_integer(node.type) return IntegralConstant(value=int(node.value)) @@ -193,7 +194,7 @@ def _is_scan(node: itir.Node) -> TypeGuard[itir.FunCall]: def _bool_from_literal(node: itir.Node) -> bool: assert isinstance(node, itir.Literal) - assert node.type == "bool" and node.value in ("True", "False") + assert type_info.is_logical(node.type) and node.value in ("True", "False") return node.value == "True" @@ -296,7 +297,7 @@ def visit_Lambda( ) def visit_Literal(self, node: itir.Literal, **kwargs: Any) -> Literal: - return Literal(value=node.value, type=node.type) + return Literal(value=node.value, type=node.type.kind.name.lower()) def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs: Any) -> OffsetLiteral: return OffsetLiteral(value=node.value) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 487404b207..7a1e9bc4fa 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -22,7 +22,7 @@ from gt4py.next.common import Connectivity from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.type_system import type_specifications as ts, type_translation as tt +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt from .itir_to_tasklet import ( Context, @@ -64,7 +64,7 @@ def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: """ stencil_fobj = cast(FunCall, stencil) is_forward = stencil_fobj.args[1] - assert isinstance(is_forward, Literal) and is_forward.type == "bool" + assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type) init_carry = stencil_fobj.args[2] assert isinstance(init_carry, Literal) return is_forward.value == "True", init_carry diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index b235e6f26d..a05b9afde8 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -223,6 +223,25 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] +def is_integer(symbol_type: ts.TypeSpec) -> bool: + """ + Check if ``symbol_type`` is an integral type. + + Examples: + --------- + >>> is_integer(ts.ScalarType(kind=ts.ScalarKind.INT32)) + True + >>> is_integer(ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + False + >>> is_integer(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) + False + """ + return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in { + ts.ScalarKind.INT32, + ts.ScalarKind.INT64, + } + + def is_integral(symbol_type: ts.TypeSpec) -> bool: """ Check if the dtype of ``symbol_type`` is an integral type. @@ -236,7 +255,7 @@ def is_integral(symbol_type: ts.TypeSpec) -> bool: >>> is_integral(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ - return extract_dtype(symbol_type).kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] + return is_integer(extract_dtype(symbol_type)) def is_number(symbol_type: ts.TypeSpec) -> bool: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index 49c5b11b20..3d296b6377 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -23,6 +23,7 @@ from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts from next_tests.past_common_fixtures import ( IDim, @@ -59,7 +60,7 @@ def test_copy_lowering(copy_program_def, itir_identity_fundef): fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type="int32"), + P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), ], ) @@ -118,8 +119,20 @@ def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef) fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="1", type=itir.INTEGER_INDEX_BUILTIN), - P(itir.Literal, value="2", type=itir.INTEGER_INDEX_BUILTIN), + P( + itir.Literal, + value="1", + type=ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ), + ), + P( + itir.Literal, + value="2", + type=ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ), + ), ], ) ], diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index b02c610aff..a2e9a8ada2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.pretty_parser import pparse +from gt4py.next.iterator.ir_utils import ir_makers as im def test_symref(): @@ -41,14 +42,14 @@ def test_arithmetic(): ir.FunCall( fun=ir.SymRef(id="plus"), args=[ - ir.Literal(value="1", type="int32"), - ir.Literal(value="2", type="int32"), + im.literal("1", "int32"), + im.literal("2", "int32"), ], ), - ir.Literal(value="3", type="int32"), + im.literal("3", "int32"), ], ), - ir.Literal(value="4", type="int32"), + im.literal("4", "int32"), ], ) actual = pparse(testee) @@ -115,7 +116,7 @@ def test_tuple_get(): testee = "x[42]" expected = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[ir.Literal(value="42", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index bc1372cea6..44308473b7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat +from gt4py.next.iterator.ir_utils import ir_makers as im def test_hmerge(): @@ -111,14 +112,14 @@ def test_arithmetic(): ir.FunCall( fun=ir.SymRef(id="plus"), args=[ - ir.Literal(value="1", type="int64"), - ir.Literal(value="2", type="int64"), + im.literal("1", "int64"), + im.literal("2", "int64"), ], ), - ir.Literal(value="3", type="int64"), + im.literal("3", "int64"), ], ), - ir.Literal(value="4", type="int64"), + im.literal("4", "int64"), ], ) expected = "(1 + 2) × 3 / 4" @@ -132,11 +133,11 @@ def test_associativity(): args=[ ir.FunCall( fun=ir.SymRef(id="plus"), - args=[ir.Literal(value="1", type="int64"), ir.Literal(value="2", type="int64")], + args=[im.literal("1", "int64"), im.literal("2", "int64")], ), ir.FunCall( fun=ir.SymRef(id="plus"), - args=[ir.Literal(value="3", type="int64"), ir.Literal(value="4", type="int64")], + args=[im.literal("3", "int64"), im.literal("4", "int64")], ), ], ) @@ -204,7 +205,7 @@ def test_shift(): def test_tuple_get(): testee = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[ir.Literal(value="42", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) expected = "x[42]" actual = pformat(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7beda20d31..731c163343 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -80,7 +80,7 @@ def test_sym_ref(): def test_bool_literal(): - testee = ir.Literal(value="False", type="bool") + testee = im.literal_from_value(False) expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) assert inferred == expected @@ -88,7 +88,7 @@ def test_bool_literal(): def test_int_literal(): - testee = ir.Literal(value="3", type="int32") + testee = im.literal("3", "int32") expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) assert inferred == expected @@ -96,7 +96,7 @@ def test_int_literal(): def test_float_literal(): - testee = ir.Literal(value="3.0", type="float64") + testee = im.literal("3.0", "float64") expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="float64"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) assert inferred == expected @@ -223,7 +223,7 @@ def test_and(): def test_cast(): testee = ir.FunCall( fun=ir.SymRef(id="cast_"), - args=[ir.Literal(value="1.", type="float64"), ir.SymRef(id="int64")], + args=[im.literal("1.", "float64"), ir.SymRef(id="int64")], ) expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int64"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) @@ -342,8 +342,8 @@ def test_make_tuple(): testee = ir.FunCall( fun=ir.SymRef(id="make_tuple"), args=[ - ir.Literal(value="True", type="bool"), - ir.Literal(value="42.0", type="float64"), + im.literal("True", "bool"), + im.literal("42.0", "float64"), ir.SymRef(id="x"), ], ) @@ -363,12 +363,12 @@ def test_tuple_get(): testee = ir.FunCall( fun=ir.SymRef(id="tuple_get"), args=[ - ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("1", ir.INTEGER_INDEX_BUILTIN), ir.FunCall( fun=ir.SymRef(id="make_tuple"), args=[ - ir.Literal(value="True", type="bool"), - ir.Literal(value="42.0", type="float64"), + im.literal("True", "bool"), + im.literal("42.0", "float64"), ], ), ], @@ -384,7 +384,7 @@ def test_tuple_get_in_lambda(): params=[ir.Sym(id="x")], expr=ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("1", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ), ) expected = ti.FunctionType( @@ -449,9 +449,7 @@ def test_reduce(): ], ), ) - testee = ir.FunCall( - fun=ir.SymRef(id="reduce"), args=[reduction_f, ir.Literal(value="0", type="int64")] - ) + testee = ir.FunCall(fun=ir.SymRef(id="reduce"), args=[reduction_f, im.literal("0", "int64")]) expected = ti.FunctionType( args=ti.ValListTuple( kind=ti.Value(), @@ -486,7 +484,7 @@ def test_scan(): ) testee = ir.FunCall( fun=ir.SymRef(id="scan"), - args=[scan_f, ir.Literal(value="True", type="bool"), ir.Literal(value="0", type="int64")], + args=[scan_f, im.literal("True", "bool"), im.literal("0", "int64")], ) expected = ti.FunctionType( args=ti.Tuple.from_elems( @@ -697,7 +695,7 @@ def test_dynamic_offset(): fun=ir.SymRef(id="named_range"), args=[ ir.AxisLiteral(value="IDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="i"), ], ), @@ -705,7 +703,7 @@ def test_dynamic_offset(): fun=ir.SymRef(id="named_range"), args=[ ir.AxisLiteral(value="JDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="j"), ], ), @@ -713,7 +711,7 @@ def test_dynamic_offset(): fun=ir.SymRef(id="named_range"), args=[ ir.AxisLiteral(value="KDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="k"), ], ), @@ -839,8 +837,8 @@ def test_fencil_definition_same_closure_input(): domain=im.call("unstructured_domain")( im.call("named_range")( ir.AxisLiteral(value="Edge"), - ir.Literal(value="0", type="int32"), - ir.Literal(value="10", type="int32"), + im.literal("0", "int32"), + im.literal("10", "int32"), ) ), stencil=im.ref("f1"), @@ -851,8 +849,8 @@ def test_fencil_definition_same_closure_input(): domain=im.call("unstructured_domain")( im.call("named_range")( ir.AxisLiteral(value="Vertex"), - ir.Literal(value="0", type="int32"), - ir.Literal(value="10", type="int32"), + im.literal("0", "int32"), + im.literal("10", "int32"), ) ), stencil=im.ref("f2"), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py index b6463ba0d5..87ed414393 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet +from gt4py.next.iterator.ir_utils import ir_makers as im def _list_get(index: ir.Expr, lst: ir.Expr) -> ir.FunCall: @@ -26,7 +27,7 @@ def _neighbors(offset: ir.Expr, it: ir.Expr) -> ir.FunCall: def test_list_get_neighbors(): testee = _list_get( - ir.Literal(value="42", type="int32"), + im.literal("42", "int32"), _neighbors(ir.OffsetLiteral(value="foo"), ir.SymRef(id="bar")), ) @@ -49,13 +50,11 @@ def test_list_get_neighbors(): def test_list_get_make_const_list(): testee = _list_get( - ir.Literal(value="42", type="int32"), - ir.FunCall( - fun=ir.SymRef(id="make_const_list"), args=[ir.Literal(value="3.14", type="float64")] - ), + im.literal("42", "int32"), + ir.FunCall(fun=ir.SymRef(id="make_const_list"), args=[im.literal("3.14", "float64")]), ) - expected = ir.Literal(value="3.14", type="float64") + expected = im.literal("3.14", "float64") actual = CollapseListGet().visit(testee) assert expected == actual diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 442f71bc9b..797ed2a703 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -241,7 +241,7 @@ def test_update_cartesian_domains(): *( im.call("named_range")( ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ) for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) @@ -269,7 +269,7 @@ def test_update_cartesian_domains(): im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.literal("1", ir.INTEGER_INDEX_BUILTIN), ), - im.plus(im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)), + im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), ], ) ] @@ -296,7 +296,7 @@ def test_update_cartesian_domains(): im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.literal("1", ir.INTEGER_INDEX_BUILTIN), ), - im.plus(im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)), + im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), ], ) ] @@ -305,7 +305,7 @@ def test_update_cartesian_domains(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ], ) @@ -324,10 +324,10 @@ def test_collect_tmps_info(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value="IDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.FunCall( fun=im.ref("plus"), - args=[im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)], + args=[im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)], ), ], ) @@ -337,7 +337,7 @@ def test_collect_tmps_info(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ], ) @@ -381,7 +381,7 @@ def test_collect_tmps_info(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py index 50a3b3ecab..c8a61a3b2f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan +from gt4py.next.iterator.ir_utils import ir_makers as im # TODO(havogt): remove duplication with test_eta_reduction @@ -25,8 +26,8 @@ def _make_scan(*args: list[str], scanpass_body: ir.Expr) -> ir.Expr: params=[ir.Sym(id="state")] + [ir.Sym(id=f"{arg}") for arg in args], expr=scanpass_body, ), - ir.Literal(value="0.0", type="float64"), - ir.Literal(value="True", type="bool"), + im.literal("0.0", "float64"), + im.literal("True", "bool"), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py index 5a9d3a676b..53678d278e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction +from gt4py.next.iterator.ir_utils import ir_makers as im def _make_scan(*args: list[str]): @@ -24,8 +25,8 @@ def _make_scan(*args: list[str]): params=[ir.Sym(id="state")] + [ir.Sym(id=f"{arg}") for arg in args], expr=ir.SymRef(id="foo"), ), - ir.Literal(value="0.0", type="float64"), - ir.Literal(value="True", type="bool"), + im.literal("0.0", "float64"), + im.literal("True", "bool"), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py index 685625e9e7..e236b7dd49 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.simple_inline_heuristic import is_eligible_for_inlining +from gt4py.next.iterator.ir_utils import ir_makers as im @pytest.fixture @@ -33,8 +34,8 @@ def scan(): ], ), ), - ir.Literal(value="True", type="bool"), - ir.Literal(value="0.0", type="float64"), + im.literal("True", "bool"), + im.literal("0.0", "float64"), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index ba4a91e6b5..054e7fac12 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -19,6 +19,7 @@ from gt4py.eve.utils import UIDs from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags +from gt4py.next.iterator.ir_utils import ir_makers as im from next_tests.unit_tests.conftest import DummyConnectivity @@ -34,7 +35,7 @@ def basic_reduction(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.FunCall( @@ -51,7 +52,7 @@ def reduction_with_shift_on_second_arg(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.SymRef(id="x"), @@ -69,7 +70,7 @@ def reduction_with_incompatible_shifts(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.FunCall( @@ -90,7 +91,7 @@ def reduction_with_irrelevant_full_shift(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.FunCall( diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index be7a9ff81e..d3651d3084 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -19,6 +19,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.otf import languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.iterator.ir_utils import ir_makers as im @pytest.fixture @@ -30,8 +31,8 @@ def fencil_example(): fun=itir.SymRef(id="named_range"), args=[ itir.AxisLiteral(value="X"), - itir.Literal(value="0", type=itir.INTEGER_INDEX_BUILTIN), - itir.Literal(value="10", type=itir.INTEGER_INDEX_BUILTIN), + im.literal("0", itir.INTEGER_INDEX_BUILTIN), + im.literal("10", itir.INTEGER_INDEX_BUILTIN), ], ) ], @@ -43,7 +44,7 @@ def fencil_example(): itir.FunctionDefinition( id="stencil", params=[itir.Sym(id="buf"), itir.Sym(id="sc")], - expr=itir.Literal(value="1", type="float64"), + expr=im.literal("1", "float64"), ) ], closures=[