From 62b23ccd3210fbd42ef9a85b645a0aa60bbedabb Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 14:28:21 +0100 Subject: [PATCH 01/11] Fix cast, tests still wip --- src/gt4py/next/ffront/foast_to_gtir.py | 6 +----- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 ++ tests/next_tests/integration_tests/cases.py | 1 + .../feature_tests/ffront_tests/test_execution.py | 13 +++++++++++++ .../unit_tests/ffront_tests/test_foast_to_gtir.py | 12 ++++++++++++ 5 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2c2971f49a..3c65695aec 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -359,11 +359,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: - if isinstance(t[0], ts.FieldType): - return im.cast_as_fieldop(str(new_type))(expr) - else: - assert isinstance(t[0], ts.ScalarType) - return im.call("cast_")(expr, str(new_type)) + return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return create_cast(obj, (node.args[0].type,)) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index a4e111e785..f42ff799af 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -485,6 +485,8 @@ def op_as_fieldop( >>> str(op_as_fieldop("op")("a", "b")) '(⇑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)' """ + # im.as_fieldop(im.lambda_("__arg0")(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32")))(im.deref("__arg0"))))("a") + # (⇑(λ(__arg0) → map_(λ(val) → cast_(val, int32))(·__arg0)))(a) if isinstance(op, (str, itir.SymRef, itir.Lambda)): op = call(op) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 9fb7850666..2489269d7b 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -69,6 +69,7 @@ IJKField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.int32] # type: ignore [valid-type] IJKFloatField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.float64] # type: ignore [valid-type] VField: TypeAlias = gtx.Field[[Vertex], np.int32] # type: ignore [valid-type] +VFloatField: TypeAlias = gtx.Field[[Vertex], np.float64] # type: ignore [valid-type] VBoolField: TypeAlias = gtx.Field[[Vertex], bool] # type: ignore [valid-type] EField: TypeAlias = gtx.Field[[Edge], np.int32] # type: ignore [valid-type] CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1a51e3667d..e520c120a7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -423,6 +423,19 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +def test_astype_int_sparse(unstructured_case): + @gtx.field_operator + def testee(a: cases.VFloatField) -> gtx.Field[[Edge, E2VDim], int64]: + return astype(a(E2V), int64) + + cases.verify_with_default_data( + unstructured_case, + testee, + ref=lambda a_: a_.astype(int64), + comparison=lambda a_, b_: np.all(a_ == b_), + ) + + @pytest.mark.uses_tuple_returns def test_astype_on_tuples(cartesian_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 516890ea46..3778a343c4 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -289,6 +289,18 @@ def foo(a: gtx.Field[[TDim], float64]): assert lowered.expr == reference +def test_astype_LocalDim(): + def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a") + + assert lowered.expr == reference + + def test_astype_scalar(): def foo(a: float64): return astype(a, int32) From 1512be578419e9d263cbfc5a9e8d1f0bd4bdb48d Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 15:02:29 +0100 Subject: [PATCH 02/11] Complete tests --- .../feature_tests/ffront_tests/test_execution.py | 13 ++++++++----- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index e520c120a7..68e01ab11a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -423,16 +423,19 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) -def test_astype_int_sparse(unstructured_case): +def test_astype_int_local_field(unstructured_case): @gtx.field_operator - def testee(a: cases.VFloatField) -> gtx.Field[[Edge, E2VDim], int64]: - return astype(a(E2V), int64) + def testee(a: cases.VFloatField) -> gtx.Field[[Edge], int64]: + tmp = astype(a(E2V), int64) + return neighbor_sum(tmp, axis=E2VDim) + + e2v_table = unstructured_case.offset_provider["E2V"].ndarray cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a_: a_.astype(int64), - comparison=lambda a_, b_: np.all(a_ == b_), + ref=lambda a: np.sum(a.astype(int64)[e2v_table], axis=1, initial=0), + comparison=lambda a, b: np.all(a == b), ) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 3778a343c4..bb64af8da9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -289,7 +289,7 @@ def foo(a: gtx.Field[[TDim], float64]): assert lowered.expr == reference -def test_astype_LocalDim(): +def test_astype_local_field(): def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): return astype(a, int32) From e9c90b207aaaf9db6575507da2b6d7cb5a16878c Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 15:24:24 +0100 Subject: [PATCH 03/11] Run pre-commit --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index f42ff799af..a4e111e785 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -485,8 +485,6 @@ def op_as_fieldop( >>> str(op_as_fieldop("op")("a", "b")) '(⇑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)' """ - # im.as_fieldop(im.lambda_("__arg0")(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32")))(im.deref("__arg0"))))("a") - # (⇑(λ(__arg0) → map_(λ(val) → cast_(val, int32))(·__arg0)))(a) if isinstance(op, (str, itir.SymRef, itir.Lambda)): op = call(op) From 82a69c2b1a5494660d8c9e94baa447f3b0aa71b4 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 15:27:48 +0100 Subject: [PATCH 04/11] Minor --- tests/next_tests/integration_tests/cases.py | 1 - .../feature_tests/ffront_tests/test_execution.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 2489269d7b..9fb7850666 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -69,7 +69,6 @@ IJKField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.int32] # type: ignore [valid-type] IJKFloatField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.float64] # type: ignore [valid-type] VField: TypeAlias = gtx.Field[[Vertex], np.int32] # type: ignore [valid-type] -VFloatField: TypeAlias = gtx.Field[[Vertex], np.float64] # type: ignore [valid-type] VBoolField: TypeAlias = gtx.Field[[Vertex], bool] # type: ignore [valid-type] EField: TypeAlias = gtx.Field[[Edge], np.int32] # type: ignore [valid-type] CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 68e01ab11a..e1f7e6fb4a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -425,7 +425,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: def test_astype_int_local_field(unstructured_case): @gtx.field_operator - def testee(a: cases.VFloatField) -> gtx.Field[[Edge], int64]: + def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: tmp = astype(a(E2V), int64) return neighbor_sum(tmp, axis=E2VDim) From eaee70dd4d247e855c932a63014641d688758cf5 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 15:55:04 +0100 Subject: [PATCH 05/11] Skipping dace tests --- .../feature_tests/ffront_tests/test_execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index e1f7e6fb4a..ba6b184627 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -424,6 +424,9 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: def test_astype_int_local_field(unstructured_case): + if unstructured_case.backend and "dace" in unstructured_case.backend.name: + pytest.skip("Skipping dace: dace_itir: deprecated soon, dace_gtir: feature missing") + @gtx.field_operator def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: tmp = astype(a(E2V), int64) From 562062806d0986561b146f24f0e4444c9402ec73 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 29 Nov 2024 16:47:05 +0100 Subject: [PATCH 06/11] fix dace_gtir --- .../dace_fieldview/gtir_python_codegen.py | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 6aee33c56e..9148284398 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any, Callable, Sequence import numpy as np @@ -118,29 +118,47 @@ class PythonCodegen(codegen.TemplatedGenerator): as in the case of field domain definitions, for sybolic array shape and map range. """ - SymRef = as_fmt("{id}") Literal = as_fmt("{value}") - def _visit_deref(self, node: gtir.FunCall) -> str: + def _visit_deref(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str: assert len(node.args) == 1 if isinstance(node.args[0], gtir.SymRef): - return self.visit(node.args[0]) + return self.visit(node.args[0], symbol_mapping=symbol_mapping) raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - def visit_FunCall(self, node: gtir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) + def _visit_lambda( + self, + node: gtir.Lambda, + node_args: Sequence[gtir.Node], + symbol_mapping: dict[str, gtir.Node], + ) -> str: + symbol_mapping |= {param.id: arg for param, arg in zip(node.params, node_args)} + return self.visit(node.expr, symbol_mapping=symbol_mapping) + + def visit_FunCall(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str: + if isinstance(node.fun, gtir.Lambda): + return self._visit_lambda(node.fun, node.args, symbol_mapping=symbol_mapping) + elif cpm.is_call_to(node, "deref"): + return self._visit_deref(node, symbol_mapping=symbol_mapping) elif isinstance(node.fun, gtir.SymRef): - args = self.visit(node.args) + args = self.visit(node.args, symbol_mapping=symbol_mapping) builtin_name = str(node.fun.id) return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + def visit_SymRef(self, node: gtir.SymRef, symbol_mapping: dict[str, gtir.Node]) -> str: + symbol = str(node.id) + if symbol_mapping and symbol in symbol_mapping: + mapped_node = symbol_mapping[symbol] + return self.visit(mapped_node, symbol_mapping=symbol_mapping) + return symbol -get_source = PythonCodegen.apply -""" -Specialized visit method for symbolic expressions. -Returns: - A string containing the Python code corresponding to a symbolic expression -""" +def get_source(node: gtir.Node) -> str: + """ + Specialized visit method for symbolic expressions. + + Returns: + A string containing the Python code corresponding to a symbolic expression + """ + return PythonCodegen.apply(node, symbol_mapping={}) From bada5b6b131eeb580bc2392af8fbe0943ab00cd3 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 17:16:12 +0100 Subject: [PATCH 07/11] Don't skip dace_gtir --- .../feature_tests/ffront_tests/test_execution.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index ba6b184627..25d260c137 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -424,8 +424,12 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: def test_astype_int_local_field(unstructured_case): - if unstructured_case.backend and "dace" in unstructured_case.backend.name: - pytest.skip("Skipping dace: dace_itir: deprecated soon, dace_gtir: feature missing") + if ( + unstructured_case.backend + and "dace" in unstructured_case.backend.name + and "itir" in unstructured_case.backend.name + ): + pytest.skip("Skipping dace_itir: deprecated soon") @gtx.field_operator def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: From 6751a07f13cecd4ec4e802045ace8fc3409dccf5 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Nov 2024 18:13:47 +0100 Subject: [PATCH 08/11] inline lambdas in tests --- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index bb64af8da9..59a8dc961b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -283,10 +283,11 @@ def foo(a: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.cast_as_fieldop("int32")("a") - assert lowered.expr == reference + assert lowered_inlined.expr == reference def test_astype_local_field(): @@ -307,10 +308,11 @@ def foo(a: float64): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.call("cast_")("a", "int32") - assert lowered.expr == reference + assert lowered_inlined.expr == reference def test_astype_tuple(): From 096c5c64d5f3932f575f6df7c70879f5458c7e7f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 2 Dec 2024 09:41:56 +0100 Subject: [PATCH 09/11] minor edit --- .../dace_fieldview/gtir_python_codegen.py | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 9148284398..95b7ce5213 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Callable, Sequence +from typing import Any, Callable import numpy as np @@ -120,37 +120,27 @@ class PythonCodegen(codegen.TemplatedGenerator): Literal = as_fmt("{value}") - def _visit_deref(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], gtir.SymRef): - return self.visit(node.args[0], symbol_mapping=symbol_mapping) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def _visit_lambda( - self, - node: gtir.Lambda, - node_args: Sequence[gtir.Node], - symbol_mapping: dict[str, gtir.Node], - ) -> str: - symbol_mapping |= {param.id: arg for param, arg in zip(node.params, node_args)} - return self.visit(node.expr, symbol_mapping=symbol_mapping) - - def visit_FunCall(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str: + def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str: if isinstance(node.fun, gtir.Lambda): - return self._visit_lambda(node.fun, node.args, symbol_mapping=symbol_mapping) + # update the mapping from lambda parameters to corresponding argument expressions + args_map |= {p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True)} + return self.visit(node.fun.expr, args_map=args_map) elif cpm.is_call_to(node, "deref"): - return self._visit_deref(node, symbol_mapping=symbol_mapping) + assert len(node.args) == 1 + if not isinstance(node.args[0], gtir.SymRef): + # shift expressions are not expected in this visitor context + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + return self.visit(node.args[0], args_map=args_map) elif isinstance(node.fun, gtir.SymRef): - args = self.visit(node.args, symbol_mapping=symbol_mapping) + args = self.visit(node.args, args_map=args_map) builtin_name = str(node.fun.id) return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") - def visit_SymRef(self, node: gtir.SymRef, symbol_mapping: dict[str, gtir.Node]) -> str: + def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str: symbol = str(node.id) - if symbol_mapping and symbol in symbol_mapping: - mapped_node = symbol_mapping[symbol] - return self.visit(mapped_node, symbol_mapping=symbol_mapping) + if symbol in args_map: + return self.visit(args_map[symbol], args_map=args_map) return symbol @@ -158,7 +148,9 @@ def get_source(node: gtir.Node) -> str: """ Specialized visit method for symbolic expressions. + The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions. + Returns: A string containing the Python code corresponding to a symbolic expression """ - return PythonCodegen.apply(node, symbol_mapping={}) + return PythonCodegen.apply(node, args_map={}) From 70c64c69bc48fcd0c64723cba42a7d9a4ac94745 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 2 Dec 2024 10:37:19 +0100 Subject: [PATCH 10/11] fix previous commit --- .../runners/dace_fieldview/gtir_python_codegen.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 95b7ce5213..4bdb602f5f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -123,8 +123,10 @@ class PythonCodegen(codegen.TemplatedGenerator): def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str: if isinstance(node.fun, gtir.Lambda): # update the mapping from lambda parameters to corresponding argument expressions - args_map |= {p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True)} - return self.visit(node.fun.expr, args_map=args_map) + lambda_args_map = args_map | { + p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True) + } + return self.visit(node.fun.expr, args_map=lambda_args_map) elif cpm.is_call_to(node, "deref"): assert len(node.args) == 1 if not isinstance(node.args[0], gtir.SymRef): From 91985a13594b09a9c973670d6101bc7395e4ad97 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 3 Dec 2024 21:57:16 +0100 Subject: [PATCH 11/11] remove pytest.skip for dace itir because, no longer needed --- .../feature_tests/ffront_tests/test_execution.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index c2d1a05ec2..4eed7f5cde 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -439,13 +439,6 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: def test_astype_int_local_field(unstructured_case): - if ( - unstructured_case.backend - and "dace" in unstructured_case.backend.name - and "itir" in unstructured_case.backend.name - ): - pytest.skip("Skipping dace_itir: deprecated soon") - @gtx.field_operator def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: tmp = astype(a(E2V), int64)