diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 029833cb7d..358f6e8d0d 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -15,7 +15,7 @@ from gt4py import eve ```python cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace( - past_to_itir=gtx.ffront.past_to_itir.past_to_itir_factory(cached=False) + past_to_itir=gtx.ffront.past_to_itir.past_to_gtir_factory(cached=False) ) ``` diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e223d7771c..e075422ca3 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -16,7 +16,6 @@ from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( foast_to_gtir, - foast_to_itir, foast_to_past, func_to_foast, func_to_past, @@ -41,7 +40,7 @@ ARGS: typing.TypeAlias = arguments.JITArgs CARG: typing.TypeAlias = arguments.CompileTimeArgs -IT_PRG: typing.TypeAlias = itir.FencilDefinition | itir.Program +IT_PRG: typing.TypeAlias = itir.Program INPUT_DATA: typing.TypeAlias = DSL_FOP | FOP | DSL_PRG | PRG | IT_PRG @@ -93,7 +92,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) past_to_itir: workflow.Workflow[AOT_PRG, stages.CompilableProgram] = dataclasses.field( - default_factory=past_to_itir.past_to_itir_factory + default_factory=past_to_itir.past_to_gtir_factory ) def step_order(self, inp: INPUT_PAIR) -> list[str]: @@ -126,7 +125,7 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: ) case PRG(): steps.extend(["past_lint", "field_view_prog_args_transform", "past_to_itir"]) - case itir.FencilDefinition() | itir.Program(): + case itir.Program(): pass case _: raise ValueError("Unexpected input.") @@ -135,17 +134,6 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() -# FIXME[#1582](havogt): remove after refactoring to GTIR -# note: this step is deliberately placed here, such that the cache is shared -_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) -LEGACY_TRANSFORMS: Transforms = Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), - foast_to_itir=_foast_to_itir_step, - field_view_op_to_prog=foast_to_past.operator_to_program_factory( - foast_to_itir_step=_foast_to_itir_step - ), -) - # TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. Maybe: # `Backend` -> `Toolchain` diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 61756f30c9..d187095019 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -34,7 +34,6 @@ from gt4py.next.ffront import ( field_operator_ast as foast, foast_to_gtir, - foast_to_itir, past_process_args, signature, stages as ffront_stages, @@ -186,7 +185,7 @@ def _all_closure_vars(self) -> dict[str, Any]: return transform_utils._get_closure_vars_recursively(self.past_stage.closure_vars) @functools.cached_property - def itir(self) -> itir.FencilDefinition: + def gtir(self) -> itir.Program: no_args_past = toolchain.CompilableProgram( data=ffront_stages.PastProgramDefinition( past_node=self.past_stage.past_node, @@ -561,7 +560,7 @@ def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: # a different backend than the one of the program that calls this field operator. Just use # the hard-coded lowering until this is cleaned up. def __gt_itir__(self) -> itir.FunctionDefinition: - return foast_to_itir.foast_to_itir(self.foast_stage) + return foast_to_gtir.foast_to_gtir(self.foast_stage) # FIXME[#1582](tehrengruber): remove after refactoring to GTIR def __gt_gtir__(self) -> itir.FunctionDefinition: diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py deleted file mode 100644 index 538b0f3ddb..0000000000 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ /dev/null @@ -1,512 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# FIXME[#1582](havogt): remove after refactoring to GTIR - -import dataclasses -from typing import Any, Callable, Optional - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Never -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.ffront import ( - dialect_ast_enums, - fbuiltins, - field_operator_ast as foast, - lowering_utils, - stages as ffront_stages, - type_specifications as ts_ffront, -) -from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES -from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.ffront.stages import AOT_FOP, FOP -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts - - -def foast_to_itir(inp: FOP) -> itir.Expr: - """ - Lower a FOAST field operator node to Iterator IR. - - See the docstring of `FieldOperatorLowering` for details. - """ - return FieldOperatorLowering.apply(inp.foast_node) - - -def foast_to_itir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Expr]: - """Wrap `foast_to_itir` into a chainable and, optionally, cached workflow step.""" - wf = foast_to_itir - if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) - return wf - - -def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, itir.Expr]: - """Wrap the `foast_to_itir` workflow step into an adapter to fit into backend transform workflows.""" - return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) - - -def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node_type): - return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) - return lambda x: x - - -@dataclasses.dataclass -class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): - """ - Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). - - The strategy is to lower every expression to lifted stencils, - i.e. taking iterators and returning iterator. - - Examples - -------- - >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser - >>> from gt4py.next import Field, Dimension, float64 - >>> - >>> IDim = Dimension("IDim") - >>> def fieldop(inp: Field[[IDim], "float64"]): - ... return inp - >>> - >>> parsed = FieldOperatorParser.apply_to_function(fieldop) - >>> lowered = FieldOperatorLowering.apply(parsed) - >>> type(lowered) - - >>> lowered.id - SymbolName('fieldop') - >>> lowered.params # doctest: +ELLIPSIS - [Sym(id=SymbolName('inp'))] - """ - - uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) - - @classmethod - def apply(cls, node: foast.LocatedNode) -> itir.Expr: - return cls().visit(node) - - def visit_FunctionDefinition( - self, node: foast.FunctionDefinition, **kwargs: Any - ) -> itir.FunctionDefinition: - params = self.visit(node.params) - return itir.FunctionDefinition( - id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) - ) # `expr` is a lifted stencil - - def visit_FieldOperator( - self, node: foast.FieldOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - - new_body = func_definition.expr - - return itir.FunctionDefinition( - id=func_definition.id, params=func_definition.params, expr=new_body - ) - - def visit_ScanOperator( - self, node: foast.ScanOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - # note: we don't need the axis here as this is handled by the program - # decorator - assert isinstance(node.type, ts_ffront.ScanOperatorType) - - # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. - # In iterator IR we didn't properly specify if this is legal, - # however after lift-inlining the expressions are transformed back to literals. - forward = im.deref(self.visit(node.forward, **kwargs)) - init = lowering_utils.process_elements( - im.deref, self.visit(node.init, **kwargs), node.init.type - ) - - # lower definition function - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = im.let( - func_definition.params[0].id, - # promote carry to iterator of tuples - # (this is the only place in the lowering were a variable is captured in a lifted lambda) - lowering_utils.to_tuples_of_iterator( - im.promote_to_const_iterator(func_definition.params[0].id), - [*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 [unnecessary-iterable-allocation-for-first-element] - ), - )( - # the function itself returns a tuple of iterators, deref element-wise - lowering_utils.process_elements( - im.deref, func_definition.expr, node.type.definition.returns - ) - ) - - stencil_args: list[itir.Expr] = [] - assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args - for param, arg_type in zip( - func_definition.params[1:], - [*node.type.definition.pos_or_kw_args.values()][1:], - strict=True, - ): - if isinstance(arg_type, ts.TupleType): - # convert into iterator of tuples - stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) - - new_body = im.let( - param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) - )(new_body) - else: - stencil_args.append(im.ref(param.id)) - - definition = itir.Lambda(params=func_definition.params, expr=new_body) - - body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) - - return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) - - def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: - raise AssertionError("Statements must always be visited in the context of a function.") - - def visit_Return( - self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return self.visit(node.value, **kwargs) - - def visit_BlockStmt( - self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - for stmt in reversed(node.stmts): - inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) - assert inner_expr - return inner_expr - - def visit_IfStmt( - self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) - assert ( - isinstance(node.condition.type, ts.ScalarType) - and node.condition.type.kind == ts.ScalarKind.BOOL - ) - - cond = self.visit(node.condition, **kwargs) - - return_kind: StmtReturnKind = deduce_stmt_return_kind(node) - - common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols - - if return_kind is StmtReturnKind.NO_RETURN: - # pack the common symbols into a tuple - common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) - - # apply both branches and extract the common symbols through the prepared tuple - true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) - - # unpack the common symbols' tuple for `inner_expr` - for i, sym in enumerate(common_symbols.keys()): - inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) - - # here we assume neither branch returns - return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( - inner_expr - ) - elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: - common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) - common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) - - # wrap the inner expression in a lambda function. note that this increases the - # operation count if both branches are evaluated. - inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") - inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) - inner_expr = im.call(inner_expr_name)(*common_symrefs) - - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.let(inner_expr_name, inner_expr_evaluator)( - im.if_(im.deref(cond), true_branch, false_branch) - ) - - assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN - - # note that we do not duplicate `inner_expr` here since if both branches - # return, `inner_expr` is ignored. - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.if_(im.deref(cond), true_branch, false_branch) - - def visit_Assign( - self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( - inner_expr - ) - - def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: - return im.sym(node.id) - - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: - return im.ref(node.id) - - def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: - return im.tuple_get(node.index, self.visit(node.value, **kwargs)) - - def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: - return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - - def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: - # TODO(tehrengruber): extend iterator ir to support unary operators - dtype = type_info.extract_dtype(node.type) - if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: - if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._lower_and_map("not_", node.operand) - - return self._lower_and_map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) - - def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: - op = "if_" - args = (node.condition, node.true_expr, node.false_expr) - lowered_args: list[itir.Expr] = [ - lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) - for arg in args - ] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [ - promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args) - ] - op = im.call("map_")(op) - - return lowering_utils.to_tuples_of_iterator( - im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type - ) - - def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - current_expr = self.visit(node.func, **kwargs) - - for arg in node.args: - match arg: - # `field(Off[idx])` - case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): - current_expr = im.lift( - im.lambda_("it")(im.deref(im.shift(offset_name, offset_index)("it"))) - )(current_expr) - # `field(Dim + idx)` - case foast.BinOp( - op=dialect_ast_enums.BinaryOperator.ADD - | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name(id=dimension), - right=foast.Constant(value=offset_index), - ): - if arg.op == dialect_ast_enums.BinaryOperator.SUB: - offset_index *= -1 - current_expr = im.lift( - # TODO(SF-N): we rely on the naming-convention that the cartesian dimensions - # are passed suffixed with `off`, e.g. the `K` is passed as `Koff` in the - # offset provider. This is a rather unclean solution and should be - # improved. - im.lambda_("it")( - im.deref( - im.shift( - common.dimension_to_implicit_offset(dimension), offset_index - )("it") - ) - ) - )(current_expr) - # `field(Off)` - case foast.Name(id=offset_name): - # only a single unstructured shift is supported so returning here is fine even though we - # are in a loop. - assert len(node.args) == 1 and len(arg.type.target) > 1 # type: ignore[attr-defined] # ensured by pattern - return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - # `field(as_offset(Off, offset_field))` - case foast.Call(func=foast.Name(id="as_offset")): - func_args = arg - # TODO(tehrengruber): Use type system to deduce the offset dimension instead of - # (e.g. to allow aliasing) - offset_dim = func_args.args[0] - assert isinstance(offset_dim, foast.Name) - offset_it = self.visit(func_args.args[1], **kwargs) - current_expr = im.lift( - im.lambda_("it", "offset")( - im.deref(im.shift(offset_dim.id, im.deref("offset"))("it")) - ) - )(current_expr, offset_it) - case _: - raise FieldOperatorLoweringError("Unexpected shift arguments!") - - return current_expr - - def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - if type_info.type_class(node.func.type) is ts.FieldType: - return self._visit_shift(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: - return self._visit_math_built_in(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in ( - FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES - ): - visitor = getattr(self, f"_visit_{node.func.id}") - return visitor(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: - return self._visit_type_constr(node, **kwargs) - elif isinstance( - node.func.type, - (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), - ): - # ITIR has no support for keyword arguments. Instead, we concatenate both positional - # and keyword arguments and use the unique order as given in the function signature. - lowered_args, lowered_kwargs = type_info.canonicalize_arguments( - node.func.type, - self.visit(node.args, **kwargs), - self.visit(node.kwargs, **kwargs), - use_signature_ordering=True, - ) - result = im.call(self.visit(node.func, **kwargs))( - *lowered_args, *lowered_kwargs.values() - ) - - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - result = lowering_utils.to_tuples_of_iterator( - result, node.func.type.definition.returns - ) - - return result - - raise AssertionError( - f"Call to object of type '{type(node.func.type).__name__}' not understood." - ) - - def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, new_type = node.args[0], node.args[1].id - return lowering_utils.process_elements( - lambda x: im.promote_to_lifted_stencil( - im.lambda_("it")(im.call("cast_")("it", str(new_type))) - )(x), - self.visit(obj, **kwargs), - obj.type, - ) - - def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - condition, true_value, false_value = node.args - - lowered_condition = self.visit(condition, **kwargs) - return lowering_utils.process_elements( - lambda tv, fv, types: _map( - "if_", (lowered_condition, tv, fv), (condition.type, *types) - ), - [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], - node.type, - (node.args[1].type, node.args[2].type), - ) - - _visit_concat_where = _visit_where - - def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) - - def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) - - def _make_reduction_expr( - self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any - ) -> itir.Expr: - # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) - it = self.visit(node.args[0], **kwargs) - assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, im.deref(init_expr))) - return im.promote_to_lifted_stencil(val)(it) - - def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) - - def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - min_value, _ = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(min_value), dtype) - return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) - - def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - _, max_value = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(max_value), dtype) - return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) - - def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - el = node.args[0] - node_kind = self.visit(node.type).kind.name.lower() - source_type = {**fbuiltins.BUILTINS, "string": str}[el.type.__str__().lower()] - target_type = fbuiltins.BUILTINS[node_kind] - - if isinstance(el, foast.Constant): - val = source_type(el.value) - elif isinstance(el, foast.UnaryOp) and isinstance(el.operand, foast.Constant): - operand = source_type(el.operand.value) - val = eval(f"lambda arg: {el.op}arg")(operand) - else: - raise FieldOperatorLoweringError( - f"Type cast only supports literal arguments, {node.type} not supported." - ) - val = target_type(val) - - return im.promote_to_const_iterator(im.literal(str(val), node_kind)) - - def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: - # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; - # the following constructs work if they are removed by inlining. - if isinstance(type_, ts.TupleType): - return im.make_tuple( - *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) - ) - elif isinstance(type_, ts.ScalarType): - typename = type_.kind.name.lower() - return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type '{type_}'.") - - def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: - return self._make_literal(node.value, node.type) - - def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - return _map( - op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) - ) - - -def _map( - op: itir.Expr | str, - lowered_args: tuple, - original_arg_types: tuple[ts.TypeSpec, ...], -) -> itir.FunCall: - """ - Mapping includes making the operation an lifted stencil (first kind of mapping), but also `itir.map_`ing lists. - """ - if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): - lowered_args = tuple( - promote_to_list(arg_type)(larg) - for arg_type, larg in zip(original_arg_types, lowered_args) - ) - op = im.call("map_")(op) - - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - - -class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c0348bb5c6..4ec12bb76b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import functools from typing import Any, Optional, cast import devtools @@ -19,7 +18,6 @@ from gt4py.next.ffront import ( fbuiltins, gtcallable, - lowering_utils, program_ast as past, stages as ffront_stages, transform_utils, @@ -32,10 +30,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove # dependency as soon as column axis can be deduced from ITIR in consumers of the CompilableProgram. -def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgram: +def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: """ Lower a PAST program definition to Iterator IR. @@ -59,7 +56,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ... column_axis=None, ... ) - >>> itir_copy = past_to_itir( + >>> itir_copy = past_to_gtir( ... toolchain.CompilableProgram(copy_program.past_stage, compile_time_args) ... ) @@ -67,7 +64,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra copy_program >>> print(type(itir_copy.data)) - + """ all_closure_vars = transform_utils._get_closure_vars_recursively(inp.data.closure_vars) offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( @@ -88,13 +85,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra # making this step aware of the toolchain it is called by (it can be part of multiple). lowered_funcs = [] for gt_callable in gt_callables: - if to_gtir: - lowered_funcs.append(gt_callable.__gt_gtir__()) - else: - lowered_funcs.append(gt_callable.__gt_itir__()) + lowered_funcs.append(gt_callable.__gt_gtir__()) itir_program = ProgramLowering.apply( - inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir + inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) if config.DEBUG or inp.data.debug: @@ -106,11 +100,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ) -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR -def past_to_itir_factory( - cached: bool = True, to_gtir: bool = True +def past_to_gtir_factory( + cached: bool = True, ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: - wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) + wf = workflow.make_step(past_to_gtir) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) return wf @@ -190,7 +183,7 @@ class ProgramLowering( ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN ... ) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP - + >>> lowered.id # doctest: +SKIP SymbolName('program') >>> lowered.params # doctest: +SKIP @@ -198,7 +191,6 @@ class ProgramLowering( """ grid_type: common.GridType - to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons # the above doctest fails when executed using `pytest --doctest-modules`. @@ -209,11 +201,8 @@ def apply( node: past.Program, function_definitions: list[itir.FunctionDefinition], grid_type: common.GridType, - to_gtir: bool = False, # FIXME[#1582](havogt): remove after refactoring to GTIR - ) -> itir.FencilDefinition: - return cls(grid_type=grid_type, to_gtir=to_gtir).visit( - node, function_definitions=function_definitions - ) + ) -> itir.Program: + return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: """Generate symbols for each field param and dimension.""" @@ -246,7 +235,7 @@ def visit_Program( *, function_definitions: list[itir.FunctionDefinition], **kwargs: Any, - ) -> itir.FencilDefinition | itir.Program: + ) -> itir.Program: # The ITIR does not support dynamically getting the size of a field. As # a workaround we add additional arguments to the fencil definition # containing the size of all fields. The caller of a program is (e.g. @@ -259,27 +248,17 @@ def visit_Program( params = params + self._gen_size_params_from_program(node) implicit_domain = True - if self.to_gtir: - set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body] - return itir.Program( - id=node.id, - function_definitions=function_definitions, - params=params, - declarations=[], - body=set_ats, - implicit_domain=implicit_domain, - ) - else: - closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body] - return itir.FencilDefinition( - id=node.id, - function_definitions=function_definitions, - params=params, - closures=closures, - implicit_domain=implicit_domain, - ) + set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] + return itir.Program( + id=node.id, + function_definitions=function_definitions, + params=params, + declarations=[], + body=set_ats, + implicit_domain=implicit_domain, + ) - def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir.SetAt: + def _visit_field_operator_call(self, node: past.Call, **kwargs: Any) -> itir.SetAt: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) @@ -303,56 +282,6 @@ def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir. target=output, ) - # FIXME[#1582](havogt): remove after refactoring to GTIR - def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: - assert isinstance(node.kwargs["out"].type, ts.TypeSpec) - assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) - - node_kwargs = {**node.kwargs} - domain = node_kwargs.pop("domain", None) - output, lowered_domain = self._visit_stencil_call_out_arg( - node_kwargs.pop("out"), domain, **kwargs - ) - - assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) - - args, node_kwargs = type_info.canonicalize_arguments( - node.func.type, node.args, node_kwargs, use_signature_ordering=True - ) - - lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) - - stencil_params = [] - stencil_args: list[itir.Expr] = [] - for i, arg in enumerate([*args, *node_kwargs]): - stencil_params.append(f"__stencil_arg{i}") - if isinstance(arg.type, ts.TupleType): - # convert into tuple of iterators - stencil_args.append( - lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) - ) - else: - stencil_args.append(im.ref(f"__stencil_arg{i}")) - - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # scan operators return an iterator of tuples, just deref directly - stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) - else: - # field operators return a tuple of iterators, deref element-wise - stencil_body = lowering_utils.process_elements( - im.deref, - im.call(node.func.id)(*stencil_args), - node.func.type.definition.returns, - ) - - return itir.StencilClosure( - domain=lowered_domain, - stencil=im.lambda_(*stencil_params)(stencil_body), - inputs=[*lowered_args, *lowered_kwargs.values()], - output=output, - location=node.location, - ) - def _visit_slice_bound( self, slice_bound: Optional[past.Constant], diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 6efee29362..e875709631 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -9,7 +9,7 @@ from typing import ClassVar, List, Optional, Union import gt4py.eve as eve -from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve import Coerced, SymbolName, SymbolRef from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @@ -19,10 +19,6 @@ DimensionKind = common.DimensionKind -# TODO(havogt): -# After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere. -# During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards. - @noninstantiable class Node(eve.Node): @@ -97,23 +93,6 @@ class FunctionDefinition(Node, SymbolTableTrait): expr: Expr -class StencilClosure(Node): - domain: FunCall - stencil: Expr - output: Union[SymRef, FunCall] - inputs: List[Union[SymRef, FunCall]] - - @datamodels.validator("output") - def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): - raise ValueError("Only FunCall to 'make_tuple' allowed.") - - @datamodels.validator("inputs") - def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value): - raise ValueError("Only FunCall to 'index' allowed.") - - UNARY_MATH_NUMBER_BUILTINS = {"abs"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -195,18 +174,6 @@ def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu } -class FencilDefinition(Node, ValidatedSymbolTableTrait): - id: Coerced[SymbolName] - function_definitions: List[FunctionDefinition] - params: List[Sym] - closures: List[StencilClosure] - implicit_domain: bool = False - - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ - Sym(id=name) for name in sorted(BUILTINS) - ] # sorted for serialization stability - - class Stmt(Node): ... @@ -252,8 +219,6 @@ class Program(Node, ValidatedSymbolTableTrait): Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign] FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign] FunctionDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] -StencilClosure.__hash__ = Node.__hash__ # type: ignore[method-assign] -FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index b4a673772f..29b30beae1 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -216,10 +216,6 @@ def function_definition(self, *args: ir.Node) -> ir.FunctionDefinition: fid, *params, expr = args return ir.FunctionDefinition(id=fid, params=params, expr=expr) - def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: - output, stencil, *inputs, domain = args - return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - def if_stmt(self, cond: ir.Expr, *args): found_else_seperator = False true_branch = [] @@ -249,23 +245,6 @@ def set_at(self, *args: ir.Expr) -> ir.SetAt: target, domain, expr = args return ir.SetAt(expr=expr, domain=domain, target=target) - # TODO(havogt): remove after refactoring. - def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: - params = [] - function_definitions = [] - closures = [] - for arg in args: - if isinstance(arg, ir.Sym): - params.append(arg) - elif isinstance(arg, ir.FunctionDefinition): - function_definitions.append(arg) - else: - assert isinstance(arg, ir.StencilClosure) - closures.append(arg) - return ir.FencilDefinition( - id=fid, function_definitions=function_definitions, params=params, closures=closures - ) - def program(self, fid: str, *args: ir.Node) -> ir.Program: params = [] function_definitions = [] diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 99287f8a11..a25f99356c 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -248,28 +248,6 @@ def visit_FunctionDefinition(self, node: ir.FunctionDefinition, prec: int) -> li vbody = self._vmerge(params, self._indent(expr)) return self._optimum(hbody, vbody) - def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[str]: - assert prec == 0 - domain = self.visit(node.domain, prec=0) - stencil = self.visit(node.stencil, prec=0) - output = self.visit(node.output, prec=0) - inputs = self.visit(node.inputs, prec=0) - - hinputs = self._hmerge(["("], *self._hinterleave(inputs, ", "), [")"]) - vinputs = self._vmerge(["("], *self._hinterleave(inputs, ",", indent=True), [")"]) - inputs = self._optimum(hinputs, vinputs) - - head = self._hmerge(output, [" ← "]) - foot = self._hmerge(inputs, [" @ "], domain, [";"]) - - h = self._hmerge(head, ["("], stencil, [")"], foot) - v = self._vmerge( - self._hmerge(head, ["("]), - self._indent(self._indent(stencil)), - self._indent(self._hmerge([")"], foot)), - ) - return self._optimum(h, v) - def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: start, end = [node.id + " = temporary("], [");"] args = [] @@ -312,25 +290,6 @@ def visit_IfStmt(self, node: ir.IfStmt, *, prec: int) -> list[str]: head, self._indent(true_branch), ["} else {"], self._indent(false_branch), ["}"] ) - def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: - assert prec == 0 - function_definitions = self.visit(node.function_definitions, prec=0) - closures = self.visit(node.closures, prec=0) - params = self.visit(node.params, prec=0) - - hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) - vparams = self._vmerge( - [node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"] - ) - params = self._optimum(hparams, vparams) - - function_definitions = self._vmerge(*function_definitions) - closures = self._vmerge(*closures) - - return self._vmerge( - params, self._indent(function_definitions), self._indent(closures), ["}"] - ) - def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 81e9551e5c..12c86680b5 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -258,7 +258,7 @@ def _contains_tuple_dtype_field(arg): return isinstance(arg, common.Field) and any(dim is None for dim in arg.domain.dims) -def _make_fencil_params(fun, args) -> list[Sym]: +def _make_program_params(fun, args) -> list[Sym]: params: list[Sym] = [] param_infos = list(inspect.signature(fun).parameters.values()) @@ -293,18 +293,16 @@ def _make_fencil_params(fun, args) -> list[Sym]: return params -def trace_fencil_definition( - fun: typing.Callable, args: typing.Iterable -) -> itir.FencilDefinition | itir.Program: +def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> itir.Program: """ - Transform fencil given as a callable into `itir.FencilDefinition` using tracing. + Transform fencil given as a callable into `itir.Program` using tracing. Arguments: - fun: The fencil / callable to trace. + fun: The program / callable to trace. args: A list of arguments, e.g. fields, scalars, composites thereof, or directly a type. """ with TracerContext() as _: - params = _make_fencil_params(fun, args) + params = _make_program_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) return itir.Program( diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index aeccb5f26d..d0afc610e7 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -7,10 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator.transforms.pass_manager import ( - ITIRTransform, + GTIRTransform, apply_common_transforms, apply_fieldview_transforms, ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "ITIRTransform"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "GTIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index e71a24127f..b64886f729 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -128,7 +128,7 @@ def apply( flags = flags or cls.flags offset_provider_type = offset_provider_type or {} - if isinstance(node, (ir.Program, ir.FencilDefinition)): + if isinstance(node, ir.Program): within_stencil = False assert within_stencil in [ True, diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 824adfdd8d..4f3fcbfdd5 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -376,7 +376,7 @@ def extract_subexpression( return _NodeReplacer(expr_map).visit(node), extracted, ignored_children -ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr) +ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr) @dataclasses.dataclass(frozen=True) @@ -413,7 +413,7 @@ def apply( within_stencil: bool | None = None, offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: - is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) + is_program = isinstance(node, itir.Program) if is_program: assert within_stencil is None within_stencil = False diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py deleted file mode 100644 index 4ad91645d4..0000000000 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py import eve -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im - - -class FencilToProgram(eve.NodeTranslator): - @classmethod - def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program: - return cls().visit(node) - - def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - as_fieldop = im.call(im.call("as_fieldop")(node.stencil, node.domain))(*node.inputs) - return itir.SetAt(expr=as_fieldop, domain=node.domain, target=node.output) - - def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: - return itir.Program( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - declarations=[], - body=self.visit(node.closures), - implicit_domain=node.implicit_domain, - ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec6f89685a..ec4207d726 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,13 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Callable, Optional, Protocol +from typing import Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( - fencil_to_program, fuse_as_fieldop, global_tmps, infer_domain, @@ -32,16 +31,16 @@ from gt4py.next.iterator.type_system.inference import infer -class ITIRTransform(Protocol): +class GTIRTransform(Protocol): def __call__( - self, _: itir.Program | itir.FencilDefinition, *, offset_provider: common.OffsetProvider + self, _: itir.Program, *, offset_provider: common.OffsetProvider ) -> itir.Program: ... # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, *, offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, @@ -49,10 +48,6 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, @@ -62,9 +57,6 @@ def apply_common_transforms( if offset_provider_type is None: offset_provider_type = common.offset_provider_to_type(offset_provider) - # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this - if isinstance(ir, itir.FencilDefinition): - ir = fencil_to_program.FencilToProgram.apply(ir) assert isinstance(ir, itir.Program) tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") @@ -73,7 +65,7 @@ def apply_common_transforms( ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program + ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = NormalizeShifts().visit(ir) # note: this increases the size of the tree @@ -82,7 +74,7 @@ def apply_common_transforms( # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( - ir, # type: ignore[arg-type] # always an itir.Program + ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) @@ -119,7 +111,7 @@ def apply_common_transforms( if extract_temporaries: ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/transforms/program_to_fencil.py b/src/gt4py/next/iterator/transforms/program_to_fencil.py deleted file mode 100644 index 4411dda74f..0000000000 --- a/src/gt4py/next/iterator/transforms/program_to_fencil.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - - -def program_to_fencil(node: itir.Program) -> itir.FencilDefinition: - assert not node.declarations - closures = [] - for stmt in node.body: - assert isinstance(stmt, itir.SetAt) - assert isinstance(stmt.expr, itir.FunCall) and cpm.is_call_to(stmt.expr.fun, "as_fieldop") - stencil, domain = stmt.expr.fun.args - inputs = stmt.expr.args - assert all(isinstance(inp, itir.SymRef) for inp in inputs) - closures.append( - itir.StencilClosure(domain=domain, stencil=stencil, output=stmt.target, inputs=inputs) - ) - - return itir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - closures=closures, - ) diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py deleted file mode 100644 index 5058a91216..0000000000 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir - - -class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): - """Removes all unused input arguments from a stencil closure.""" - - def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: - if not isinstance(node.stencil, ir.Lambda): - return node - - unused: set[str] = {p.id for p in node.stencil.params} - expr = self.visit(node.stencil.expr, unused=unused, shadowed=set[str]()) - params = [] - inputs = [] - for param, inp in zip(node.stencil.params, node.inputs): - if param.id not in unused: - params.append(param) - inputs.append(inp) - - return ir.StencilClosure( - domain=node.domain, - stencil=ir.Lambda(params=params, expr=expr), - output=node.output, - inputs=inputs, - ) - - def visit_SymRef(self, node: ir.SymRef, *, unused: set[str], shadowed: set[str]) -> ir.SymRef: - if node.id not in shadowed: - unused.discard(node.id) - return node - - def visit_Lambda(self, node: ir.Lambda, *, unused: set[str], shadowed: set[str]) -> ir.Lambda: - return self.generic_visit( - node, unused=unused, shadowed=shadowed | {p.id for p in node.params} - ) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1765259a81..2903201083 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -69,7 +69,7 @@ def apply( Counter({SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) """ if ignore_builtins: - inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} + inactive_refs = {str(n.id) for n in itir.Program._NODE_SYMBOLS_} else: inactive_refs = set() diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index ffca6cc7a7..1b980783fa 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -352,7 +352,7 @@ def apply( Preconditions: - All parameters in :class:`itir.Program` and :class:`itir.FencilDefinition` must have a type + All parameters in :class:`itir.Program` must have a type defined, as they are the starting point for type propagation. Design decisions: @@ -401,9 +401,9 @@ def apply( # parts of a program. node = SanitizeTypes().visit(node) - if isinstance(node, (itir.FencilDefinition, itir.Program)): + if isinstance(node, itir.Program): assert all(isinstance(param.type, ts.DataType) for param in node.params), ( - "All parameters in 'itir.Program' and 'itir.FencilDefinition' must have a type " + "All parameters in 'itir.Program' must have a type " "defined, as they are the starting point for type propagation.", ) @@ -460,20 +460,6 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return result - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.FencilType: - params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - - function_definitions: dict[str, type_synthesizer.TypeSynthesizer] = {} - for fun_def in node.function_definitions: - function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) - - closures = self.visit(node.closures, ctx=ctx | params | function_definitions) - return it_ts.FencilType(params=params, closures=closures) - def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: @@ -532,37 +518,6 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: and target_type.dtype == expr_type.dtype ) - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: - domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) - inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) - output: ts.FieldType = self.visit(node.output, ctx=ctx) - - assert isinstance(domain, it_ts.DomainType) - for output_el in type_info.primitive_constituents(output): - assert isinstance(output_el, ts.FieldType) - - stencil_type_synthesizer = self.visit(node.stencil, ctx=ctx) - stencil_args = [ - type_synthesizer._convert_as_fieldop_input_to_iterator(domain, input_) - for input_ in inputs - ] - stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider_type=self.offset_provider_type - ) - - return it_ts.StencilClosureType( - domain=domain, - stencil=ts.FunctionType( - pos_only_args=stencil_args, - pos_or_kw_args={}, - kw_only_args={}, - returns=stencil_returns, - ), - output=output, - inputs=inputs, - ) - def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: assert ( node.value in self.dimensions diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index edb56f5659..eef8c75d0f 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -43,30 +43,6 @@ class IteratorType(ts.DataType, ts.CallableType): element_type: ts.DataType -@dataclasses.dataclass(frozen=True) -class StencilClosureType(ts.TypeSpec): - domain: DomainType - stencil: ts.FunctionType - output: ts.FieldType | ts.TupleType - inputs: list[ts.FieldType] - - def __post_init__(self): - # local import to avoid importing type_info from a type_specification module - from gt4py.next.type_system import type_info - - for i, el_type in enumerate(type_info.primitive_constituents(self.output)): - assert isinstance( - el_type, ts.FieldType - ), f"All constituent types must be field types, but the {i}-th element is of type '{el_type}'." - - -# TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere -@dataclasses.dataclass(frozen=True) -class FencilType(ts.TypeSpec): - params: dict[str, ts.DataType] - closures: list[StencilClosureType] - - @dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 85838d9c76..22326c7e87 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -26,9 +26,7 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) -CompilableProgram: TypeAlias = toolchain.CompilableProgram[ - itir.FencilDefinition | itir.Program, arguments.CompileTimeArgs -] +CompilableProgram: TypeAlias = toolchain.CompilableProgram[itir.Program, arguments.CompileTimeArgs] @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index f1649112a7..020b1f55ea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Callable, Final, Optional +from typing import Any, Final, Optional import factory import numpy as np @@ -53,9 +53,6 @@ class GTFNTranslationStep( use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -80,7 +77,7 @@ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSetting def _process_regular_arguments( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, arg_types: tuple[ts.TypeSpec, ...], offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: @@ -157,7 +154,7 @@ def _process_connectivity_args( def _preprocess_program( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( @@ -167,7 +164,6 @@ def _preprocess_program( # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, ) new_program = apply_common_transforms( @@ -186,7 +182,7 @@ def _preprocess_program( def generate_stencil_source( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: @@ -214,7 +210,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index dc0012b041..d5b34fd5b9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -108,7 +108,7 @@ def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( - f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported." + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." ) return grid_types.pop() diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index db1242e2a4..5f32eaa2bb 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. gtfn_translation = gtfn.GTFNBackendFactory().executor.translation assert isinstance(gtfn_translation, GTFNTranslationStep) diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py deleted file mode 100644 index 0a8253595e..0000000000 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ /dev/null @@ -1,67 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import Any - -from gt4py.eve.codegen import FormatTemplate as as_fmt, TemplatedGenerator -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import apply_common_transforms -from gt4py.next.program_processors import program_formatter - - -class ToLispLike(TemplatedGenerator): - Sym = as_fmt("{id}") - FunCall = as_fmt("({fun} {' '.join(args)})") - Literal = as_fmt("{value}") - OffsetLiteral = as_fmt("{value}") - SymRef = as_fmt("{id}") - StencilClosure = as_fmt( - """( - :domain {domain} - :stencil {stencil} - :output {output} - :inputs {' '.join(inputs)} - ) - """ - ) - FencilDefinition = as_fmt( - """ - ({' '.join(function_definitions)}) - (defen {id}({' '.join(params)}) - {''.join(closures)}) - """ - ) - FunctionDefinition = as_fmt( - """(defun {id}({' '.join(params)}) - {expr} - ) - -""" - ) - Lambda = as_fmt( - """(lambda ({' '.join(params)}) - {expr} - )""" - ) - - @classmethod - def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms(root, offset_provider=kwargs["offset_provider"]) - generated_code = super().apply(transformed, **kwargs) - try: - from yasi import indent_code - - indented = indent_code(generated_code, "--dialect lisp") - return "".join(indented["indented_code"]) - except ImportError: - return generated_code - - -@program_formatter.program_formatter -def format_lisp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return ToLispLike.apply(program, **kwargs) diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index f14ac5653f..cbf9fd1978 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_itir_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_itir_and_check(program: itir.Program, *args: Any, **kwargs: Any) -> str: pretty = pretty_printer.pformat(program) parsed = pretty_parser.pparse(pretty) assert parsed == program diff --git a/src/gt4py/next/program_processors/program_formatter.py b/src/gt4py/next/program_processors/program_formatter.py index f77e7f32ee..321c09668c 100644 --- a/src/gt4py/next/program_processors/program_formatter.py +++ b/src/gt4py/next/program_processors/program_formatter.py @@ -10,7 +10,7 @@ Interface for program processors. Program processors are functions which operate on a program paired with the input -arguments for the program. Programs are represented by an ``iterator.ir.itir.FencilDefinition`` +arguments for the program. Programs are represented by an ``iterator.ir.Program`` node. Program processors that execute the program with the given arguments (possibly by generating code along the way) are program executors. Those that generate any kind of string based on the program and (optionally) input values are program formatters. @@ -30,14 +30,14 @@ class ProgramFormatter: @abc.abstractmethod - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: ... + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: ... @dataclasses.dataclass(frozen=True) class WrappedProgramFormatter(ProgramFormatter): formatter: Callable[..., str] - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: return self.formatter(program, *args, **kwargs) @@ -47,7 +47,7 @@ def program_formatter(func: Callable[..., str]) -> ProgramFormatter: Examples: >>> @program_formatter - ... def format_foo(fencil: itir.FencilDefinition, *args, **kwargs) -> str: + ... def format_foo(fencil: itir.Program, *args, **kwargs) -> str: ... '''A very useless fencil formatter.''' ... return "foo" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 40d44f5ab0..a38a50d886 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -72,7 +72,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the GTIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data assert isinstance(program, itir.Program) sdfg = self.generate_sdfg( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 55f479c665..c0a9be9168 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -125,7 +125,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: Generates a unique hash string for a stencil source program representing the program, sorted offset_provider, and column_axis. """ - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 25eda5a2ed..32c3f7a360 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -90,11 +90,11 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, - transforms: itir_transforms.ITIRTransform, + transforms: itir_transforms.GTIRTransform, ) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -197,7 +197,7 @@ class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgr debug: Optional[bool] = None use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None - transforms: itir_transforms.ITIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` + transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug @@ -265,10 +265,10 @@ def decorated_fencil( gtir = next_backend.Backend( name="roundtrip_gtir", - executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # on purpose doesn't support `FencilDefintion` will resolve itself later... + executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # don't understand why mypy complains allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), + past_to_itir=past_to_itir.past_to_gtir_factory(), foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), field_view_op_to_prog=foast_to_past.operator_to_program_factory( foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 45bf7428a6..9e80dba53b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -21,7 +21,7 @@ ) -def test_program_itir_regression(cartesian_case): +def test_program_gtir_regression(cartesian_case): @gtx.field_operator(backend=None) def testee_op(a: cases.IField) -> cases.IField: return a @@ -30,8 +30,8 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, itir.Program) - assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.Program) + assert isinstance(testee.gtir, itir.Program) + assert isinstance(testee.with_backend(cartesian_case.backend).gtir, itir.Program) def test_frozen(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 66c56c4827..7d2eec772c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -107,12 +107,12 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - itir_with_tmp = apply_common_transforms( - testee.itir, + gtir_with_tmp = apply_common_transforms( + testee.gtir, extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) params = ["num_vertices", "num_edges", "num_cells"] for param in params: - assert any([param == str(p) for p in itir_with_tmp.params]) + assert any([param == str(p) for p in gtir_with_tmp.params]) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 8f6d5787d3..03662f8dcc 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -58,7 +58,6 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation - (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), ], diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py deleted file mode 100644 index c102df9d57..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ /dev/null @@ -1,598 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# TODO(tehrengruber): The style of the tests in this file is not optimal as a single change in the -# lowering can (and often does) make all of them fail. Once we have embedded field view we want to -# switch to executing the different cases here; once with a regular backend (i.e. including -# parsing) and then with embedded field view (i.e. no parsing). If the results match the lowering -# should be correct. - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -import gt4py.next as gtx -from gt4py.next import float32, float64, int32, int64, neighbor_sum -from gt4py.next.ffront import type_specifications as ts_ffront -from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering -from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts, type_translation -from gt4py.next.iterator.type_system import type_specifications as it_ts - - -IDim = gtx.Dimension("IDim") -Edge = gtx.Dimension("Edge") -Vertex = gtx.Dimension("Vertex") -Cell = gtx.Dimension("Cell") -V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. - - -def debug_itir(tree): - """Compare tree snippets while debugging.""" - from devtools import debug - - from gt4py.eve.codegen import format_python_source - from gt4py.next.program_processors import EmbeddedDSL - - debug(format_python_source(EmbeddedDSL.apply(tree))) - - -def test_copy(): - def copy_field(inp: gtx.Field[[TDim], float64]): - return inp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - assert lowered.id == "copy_field" - assert lowered.expr == im.ref("inp") - - -def test_scalar_arg(): - def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: - return alpha * bar - - parsed = FieldOperatorParser.apply_to_function(scalar_arg) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")( - "alpha", "bar" - ) # no difference to non-scalar arg - - assert lowered.expr == reference - - -def test_multicopy(): - def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1, inp2 - - parsed = FieldOperatorParser.apply_to_function(multicopy) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple("inp1", "inp2") - - assert lowered.expr == reference - - -def test_arithmetic(): - def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1 + inp2 - - parsed = FieldOperatorParser.apply_to_function(arithmetic) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") - - assert lowered.expr == reference - - -def test_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_negative_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[-1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_temp_assignment(): - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp = inp - inp = tmp - tmp2 = inp - return tmp2 - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")( - im.let( - ssa.unique_name("inp", 0), - ssa.unique_name("tmp", 0), - )( - im.let( - ssa.unique_name("tmp2", 0), - ssa.unique_name("inp", 0), - )(ssa.unique_name("tmp2", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_unary_ops(): - def unary(inp: gtx.Field[[TDim], float64]): - tmp = +inp - tmp = -tmp - return tmp - - parsed = FieldOperatorParser.apply_to_function(unary) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("0", "float64")), "inp" - ), - )( - im.let( - ssa.unique_name("tmp", 1), - im.promote_to_lifted_stencil("minus")( - im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) - ), - )(ssa.unique_name("tmp", 1)) - ) - - assert lowered.expr == reference - - -@pytest.mark.parametrize("var, var_type", [("-1.0", "float64"), ("True", "bool")]) -def test_unary_op_type_conversion(var, var_type): - def unary_float(): - return float(-1) - - def unary_bool(): - return bool(-1) - - fun = unary_bool if var_type == "bool" else unary_float - parsed = FieldOperatorParser.apply_to_function(fun) - lowered = FieldOperatorLowering.apply(parsed) - reference = im.promote_to_const_iterator(im.literal(var, var_type)) - - assert lowered.expr == reference - - -def test_unpacking(): - """Unpacking assigns should get separated.""" - - def unpacking( - inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] - ) -> gtx.Field[[TDim], float64]: - tmp1, tmp2 = inp1, inp2 # noqa - return tmp1 - - parsed = FieldOperatorParser.apply_to_function(unpacking) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("inp1", "inp2") - tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") - tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") - - reference = im.let("__tuple_tmp_0", tuple_expr)( - im.let( - ssa.unique_name("tmp1", 0), - tuple_access_0, - )( - im.let( - ssa.unique_name("tmp2", 0), - tuple_access_1, - )(ssa.unique_name("tmp1", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_annotated_assignment(): - pytest.xfail("Annotated assignments are not properly supported at the moment.") - - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp: gtx.Field[[TDim], float64] = inp - return tmp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_call(): - # create something that appears to the lowering like a field operator. - # we could also create an actual field operator, but we want to avoid - # using such heavy constructs for testing the lowering. - field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) - identity = SimpleNamespace( - __gt_type__=lambda: ts_ffront.FieldOperatorType( - definition=ts.FunctionType( - pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type - ) - ) - ) - - def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: - return identity(inp) - - parsed = FieldOperatorParser.apply_to_function(call) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.call("identity")("inp") - - assert lowered.expr == reference - - -def test_temp_tuple(): - """Returning a temp tuple should work.""" - - def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): - tmp = a, b - return tmp - - parsed = FieldOperatorParser.apply_to_function(temp_tuple) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("a", "b") - reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_unary_not(): - def unary_not(cond: gtx.Field[[TDim], "bool"]): - return not cond - - parsed = FieldOperatorParser.apply_to_function(unary_not) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("not_")("cond") - - assert lowered.expr == reference - - -def test_binary_plus(): - def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a + b - - parsed = FieldOperatorParser.apply_to_function(plus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("a", "b") - - assert lowered.expr == reference - - -def test_add_scalar_literal_to_field(): - def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: - return 2.0 + a - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" - ) - - assert lowered.expr == reference - - -def test_add_scalar_literals(): - def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: - tmp = int32(1) + int32("1") - return a + tmp - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - ), - )(im.promote_to_lifted_stencil("plus")("a", ssa.unique_name("tmp", 0))) - - assert lowered.expr == reference - - -def test_binary_mult(): - def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a * b - - parsed = FieldOperatorParser.apply_to_function(mult) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")("a", "b") - - assert lowered.expr == reference - - -def test_binary_minus(): - def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a - b - - parsed = FieldOperatorParser.apply_to_function(minus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("minus")("a", "b") - - assert lowered.expr == reference - - -def test_binary_div(): - def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a / b - - parsed = FieldOperatorParser.apply_to_function(division) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("divides")("a", "b") - - assert lowered.expr == reference - - -def test_binary_and(): - def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a & b - - parsed = FieldOperatorParser.apply_to_function(bit_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")("a", "b") - - assert lowered.expr == reference - - -def test_scalar_and(): - def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: - return a & False - - parsed = FieldOperatorParser.apply_to_function(scalar_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - "a", im.promote_to_const_iterator(im.literal("False", "bool")) - ) - - assert lowered.expr == reference - - -def test_binary_or(): - def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a | b - - parsed = FieldOperatorParser.apply_to_function(bit_or) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("or_")("a", "b") - - assert lowered.expr == reference - - -def test_compare_scalars(): - def comp_scalars() -> bool: - return 3 > 4 - - parsed = FieldOperatorParser.apply_to_function(comp_scalars) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")( - im.promote_to_const_iterator(im.literal("3", "int32")), - im.promote_to_const_iterator(im.literal("4", "int32")), - ) - - assert lowered.expr == reference - - -def test_compare_gt(): - def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a > b - - parsed = FieldOperatorParser.apply_to_function(comp_gt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")("a", "b") - - assert lowered.expr == reference - - -def test_compare_lt(): - def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a < b - - parsed = FieldOperatorParser.apply_to_function(comp_lt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("less")("a", "b") - - assert lowered.expr == reference - - -def test_compare_eq(): - def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): - return a == b - - parsed = FieldOperatorParser.apply_to_function(comp_eq) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("eq")("a", "b") - - assert lowered.expr == reference - - -def test_compare_chain(): - def compare_chain( - a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] - ) -> gtx.Field[[IDim], bool]: - return a > b > c - - parsed = FieldOperatorParser.apply_to_function(compare_chain) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - im.promote_to_lifted_stencil("greater")("a", "b"), - im.promote_to_lifted_stencil("greater")("b", "c"), - ) - - assert lowered.expr == reference - - -def test_reduction_lowering_simple(): - def reduction(edge_f: gtx.Field[[Edge], float64]): - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), - ) - ) - )(im.lifted_neighbors("V2E", "edge_f")) - - assert lowered.expr == reference - - -def test_reduction_lowering_expr(): - def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): - e1_nbh = e1(V2E) - return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( - im.promote_to_lifted_stencil("make_const_list")( - im.promote_to_const_iterator(im.literal("1.1", "float64")) - ), - im.promote_to_lifted_stencil(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), - ) - - reference = im.let( - ssa.unique_name("e1_nbh", 0), - im.lifted_neighbors("V2E", "e1"), - )( - im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref( - im.promote_to_const_iterator(im.literal(value="0", typename="float64")) - ), - ) - ) - )(mapped) - ) - - assert lowered.expr == reference - - -def test_builtin_int_constructors(): - def int_constrs() -> tuple[int32, int32, int64, int32, int64]: - return 1, int32(1), int64(1), int32("1"), int64("1") - - parsed = FieldOperatorParser.apply_to_function(int_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - ) - - assert lowered.expr == reference - - -def test_builtin_float_constructors(): - def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: - return ( - 0.1, - float(0.1), - float32(0.1), - float64(0.1), - float(".1"), - float32(".1"), - float64(".1"), - ) - - parsed = FieldOperatorParser.apply_to_function(float_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - ) - - assert lowered.expr == reference - - -def test_builtin_bool_constructors(): - def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: - return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") - - parsed = FieldOperatorParser.apply_to_function(bool_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(0)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(5)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("True")), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("False")), "bool")), - ) - - assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index a6231c22a7..c813285bd0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -46,7 +46,6 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -93,7 +92,6 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -149,9 +147,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2[1:])) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail( @@ -166,9 +162,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2)) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail @@ -194,7 +188,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): ProgramParser.apply_to_function(invalid_call_sig_program_def), function_definitions=[], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) assert exc_info.match("Invalid call to 'identity'") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py deleted file mode 100644 index fefd3c653b..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ /dev/null @@ -1,214 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import re - -import pytest - -import gt4py.eve as eve -import gt4py.next as gtx -from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import errors -from gt4py.next.ffront.func_to_past import ProgramParser -from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir -from gt4py.next.type_system import type_specifications as ts - -from next_tests.past_common_fixtures import ( - IDim, - copy_program_def, - copy_restrict_program_def, - float64, - identity_def, - invalid_call_sig_program_def, -) - - -@pytest.fixture -def itir_identity_fundef(): - return itir.FunctionDefinition( - id="identity", - params=[itir.Sym(id="x")], - expr=itir.FunCall(fun=itir.SymRef(id="deref"), args=[itir.SymRef(id="x")]), - ) - - -def test_copy_lowering(copy_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), - P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), - ], - ) - ], - ), - stencil=P( - itir.Lambda, - params=[P(itir.Sym, id=eve.SymbolName("__stencil_arg0"))], - expr=P( - itir.FunCall, - fun=P( - itir.Lambda, - params=[P(itir.Sym)], - expr=P(itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("deref"))), - ), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("identity")), - args=[P(itir.SymRef, id=eve.SymbolRef("__stencil_arg0"))], - ) - ], - ), - ), - inputs=[P(itir.SymRef, id=eve.SymbolRef("in_field"))], - output=P(itir.SymRef, id=eve.SymbolRef("out")), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_restrict_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P( - itir.Literal, - value="1", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - P( - itir.Literal, - value="2", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - ], - ) - ], - ), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_restrict_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2[1:])) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail( - reason="slicing is only allowed if all fields are sliced in the same way." -) # see ADR 10 -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2)) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail -def test_inout_prohibited(identity_def): - identity = gtx.field_operator(identity_def) - - def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): - identity(inout_field, out=inout_field) - - with pytest.raises( - ValueError, match=(r"Call to function with field as input and output not allowed.") - ): - ProgramLowering.apply( - ProgramParser.apply_to_function(inout_field_program), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - -def test_invalid_call_sig_program(invalid_call_sig_program_def): - with pytest.raises(errors.DSLError) as exc_info: - ProgramLowering.apply( - ProgramParser.apply_to_function(invalid_call_sig_program_def), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - assert exc_info.match("Invalid call to 'identity'") - # TODO(tehrengruber): re-enable again when call signature check doesn't return - # immediately after missing `out` argument - # assert ( - # re.search( - # "Function takes 1 arguments, but 2 were given.", exc_info.value.__cause__.args[0] - # ) - # is not None - # ) - assert ( - re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) - is not None - ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 817c06e8f0..2492fc446d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,21 +8,24 @@ # TODO(SF-N): test scan operator -import pytest +from typing import Iterable, Literal, Optional, Union + import numpy as np -from typing import Iterable, Optional, Literal, Union +import pytest from gt4py import eve -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next import constructors +from gt4py.next import common, constructors, utils +from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.iterator.ir_utils import domain_utils -from gt4py.next.common import Dimension -from gt4py.next import common -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next import utils +from gt4py.next.type_system import type_specifications as ts + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py deleted file mode 100644 index 407ccad924..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py +++ /dev/null @@ -1,68 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs - - -def test_simple(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected - - -def test_shadowing(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index b1ba4ccf22..03b8e3bc15 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -36,6 +36,7 @@ from . import pytestmark + dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace_fieldview")