diff --git a/src/gt4py/next/iterator/transforms/extractors.py b/src/gt4py/next/iterator/transforms/extractors.py new file mode 100644 index 0000000000..68f74970eb --- /dev/null +++ b/src/gt4py/next/iterator/transforms/extractors.py @@ -0,0 +1,84 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts + + +class SymbolNameSetExtractor(eve.NodeVisitor): + """Extract a set of symbol names""" + + def generic_visitor(self, node: itir.Node) -> set[str]: + input_fields: set[str] = set() + for child in eve.trees.iter_children_values(node): + input_fields |= self.visit(child) + return input_fields + + @classmethod + def only_fields(cls, program: itir.Program) -> set[str]: + field_param_names = [ + str(param.id) for param in program.params if isinstance(param.type, ts.FieldType) + ] + return {name for name in cls().visit(program) if name in field_param_names} + + +class InputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names passed into field operators within a program.""" + + def visit_Program(self, node: itir.Program) -> set[str]: + input_fields = set() + for stmt in node.body: + input_fields |= self.visit(stmt) + return input_fields + + def visit_IfStmt(self, node: itir.IfStmt) -> set[str]: + input_fields = set() + for stmt in node.true_branch + node.false_branch: + input_fields |= self.visit(stmt) + return input_fields + + def visit_Temporary(self, node: itir.Temporary) -> set[str]: + return set() + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.expr) + + def visit_FunCall(self, node: itir.FunCall) -> set[str]: + input_fields = set() + for arg in node.args: + input_fields |= self.visit(arg) + return input_fields + + def visit_SymRef(self, node: itir.SymRef) -> set[str]: + return {str(node.id)} + + +class OutputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names written to within a program""" + + def visit_Program(self, node: itir.Program) -> set[str]: + output_fields = set() + for stmt in node.body: + output_fields |= self.visit(stmt) + return output_fields + + def visit_IfStmt(self, node: itir.IfStmt) -> set[str]: + output_fields = set() + for stmt in node.true_branch + node.false_branch: + output_fields |= self.visit(stmt) + return output_fields + + def visit_Temporary(self, node: itir.Temporary) -> set[str]: + return set() + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.target) + + def visit_SymRef(self, node: itir.SymRef) -> set[str]: + return {str(node.id)} diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py index ea04a430b9..803ae866fb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py @@ -15,10 +15,10 @@ import dace import numpy as np -from gt4py import eve from gt4py.next import backend as next_backend, common from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import extractors as extractors from gt4py.next.otf import arguments, recipes, toolchain from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.type_system import type_specifications as ts @@ -96,10 +96,14 @@ def single_horizontal_dim_per_field( if len(horizontal_dims) == 1: yield str(field.id), horizontal_dims[0] - input_fields = (field_params[name] for name in InputNamesExtractor.only_fields(program)) + input_fields = ( + field_params[name] for name in extractors.InputNamesExtractor.only_fields(program) + ) sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields)) - output_fields = (field_params[name] for name in OutputNamesExtractor.only_fields(program)) + output_fields = ( + field_params[name] for name in extractors.OutputNamesExtractor.only_fields(program) + ) sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields)) # TODO (ricoh): bring back sdfg.offset_providers_per_input_field. @@ -191,79 +195,6 @@ def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: return [p.id for p in self.past_stage.past_node.params], [] -class SymbolNameSetExtractor(eve.NodeVisitor): - """Extract a set of symbol names""" - - def generic_visitor(self, node: itir.Node) -> set[str]: - input_fields: set[str] = set() - for child in eve.trees.iter_children_values(node): - input_fields |= self.visit(child) - return input_fields - - @classmethod - def only_fields(cls, program: itir.Program) -> set[str]: - field_param_names = [ - str(param.id) for param in program.params if isinstance(param.type, ts.FieldType) - ] - return {name for name in cls().visit(program) if name in field_param_names} - - -class InputNamesExtractor(SymbolNameSetExtractor): - """Extract the set of symbol names passed into field operators within a program.""" - - def visit_Program(self, node: itir.Program) -> set[str]: - input_fields = set() - for stmt in node.body: - input_fields |= self.visit(stmt) - return input_fields - - def visit_IfStmt(self, node: itir.IfStmt) -> set[str]: - input_fields = set() - for stmt in node.true_branch + node.false_branch: - input_fields |= self.visit(stmt) - return input_fields - - def visit_Temporary(self, node: itir.Temporary) -> set[str]: - return set() - - def visit_SetAt(self, node: itir.SetAt) -> set[str]: - return self.visit(node.expr) - - def visit_FunCall(self, node: itir.FunCall) -> set[str]: - input_fields = set() - for arg in node.args: - input_fields |= self.visit(arg) - return input_fields - - def visit_SymRef(self, node: itir.SymRef) -> set[str]: - return {str(node.id)} - - -class OutputNamesExtractor(SymbolNameSetExtractor): - """Extract the set of symbol names written to within a program""" - - def visit_Program(self, node: itir.Program) -> set[str]: - output_fields = set() - for stmt in node.body: - output_fields |= self.visit(stmt) - return output_fields - - def visit_IfStmt(self, node: itir.IfStmt) -> set[str]: - output_fields = set() - for stmt in node.true_branch + node.false_branch: - output_fields |= self.visit(stmt) - return output_fields - - def visit_Temporary(self, node: itir.Temporary) -> set[str]: - return set() - - def visit_SetAt(self, node: itir.SetAt) -> set[str]: - return self.visit(node.target) - - def visit_SymRef(self, node: itir.SymRef) -> set[str]: - return {str(node.id)} - - def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None: for dace_parsed_arg, gt4py_program_arg in zip( dace_parsed_args, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index ba24e1f7df..2a3946a77e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import warnings from collections.abc import Callable, Sequence from inspect import currentframe, getframeinfo from pathlib import Path @@ -120,13 +119,7 @@ def build_sdfg_from_itir( for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: - _, frameinfo = ( - warnings.warn( - f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg.", - stacklevel=2, - ), - getframeinfo(currentframe()), # type: ignore[arg-type] - ) + frameinfo = getframeinfo(currentframe()) # type: ignore[arg-type] nested_sdfg.debuginfo = dace.dtypes.DebugInfo( start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename ) diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index 8b8f206ef7..46a908e6f6 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -10,7 +10,6 @@ from gt4py import next as gtx from gt4py.next import common -from gt4py.next.program_processors.runners.dace_fieldview import program as dace_prg from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -20,16 +19,18 @@ JDim, KDim, Vertex, - mesh_descriptor, + mesh_descriptor, # noqa: F401 ) try: import dace + from gt4py.next.program_processors.runners.dace import gtir_cpu, gtir_gpu except ImportError: - from typing import Optional from types import ModuleType + from typing import Optional + from gt4py.next import backend as next_backend dace: Optional[ModuleType] = None @@ -66,7 +67,7 @@ def cartesian(request, gtir_dace_backend): @pytest.fixture -def unstructured(request, gtir_dace_backend, mesh_descriptor): +def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811 if gtir_dace_backend is None: yield None @@ -84,48 +85,6 @@ def unstructured(request, gtir_dace_backend, mesh_descriptor): ) -@pytest.mark.skipif(dace is None, reason="DaCe not found") -def test_input_names_extractor_cartesian(cartesian): - @gtx.field_operator(backend=cartesian.backend) - def testee_op( - a: gtx.Field[[IDim, JDim, KDim], gtx.int], - ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: - return a - - @gtx.program(backend=cartesian.backend) - def testee( - a: gtx.Field[[IDim, JDim, KDim], gtx.int], - b: gtx.Field[[IDim, JDim, KDim], gtx.int], - c: gtx.Field[[IDim, JDim, KDim], gtx.int], - ): - testee_op(b, out=c) - testee_op(a, out=b) - - input_field_names = dace_prg.InputNamesExtractor.only_fields(testee.itir) - assert input_field_names == {"a", "b"} - - -@pytest.mark.skipif(dace is None, reason="DaCe not found") -def test_output_names_extractor(cartesian): - @gtx.field_operator(backend=cartesian.backend) - def testee_op( - a: gtx.Field[[IDim, JDim, KDim], gtx.int], - ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: - return a - - @gtx.program(backend=cartesian.backend) - def testee( - a: gtx.Field[[IDim, JDim, KDim], gtx.int], - b: gtx.Field[[IDim, JDim, KDim], gtx.int], - c: gtx.Field[[IDim, JDim, KDim], gtx.int], - ): - testee_op(a, out=b) - testee_op(a, out=c) - - output_field_names = dace_prg.OutputNamesExtractor.only_fields(testee.itir) - assert output_field_names == {"b", "c"} - - @pytest.mark.skipif(dace is None, reason="DaCe not found") def test_halo_exchange_helper_attrs(unstructured): @gtx.field_operator(backend=unstructured.backend) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py new file mode 100644 index 0000000000..48d7b04dda --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py @@ -0,0 +1,103 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import typing + +import pytest + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator.transforms import extractors + +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + IDim, + JDim, + KDim, +) + + +if typing.TYPE_CHECKING: + from types import ModuleType + from typing import Optional + +try: + import dace + + from gt4py.next.program_processors.runners.dace import gtir_cpu, gtir_gpu +except ImportError: + from gt4py.next import backend as next_backend + + dace: Optional[ModuleType] = None + gtir_cpu: Optional[next_backend.Backend] = None + gtir_gpu: Optional[next_backend.Backend] = None + + +@pytest.fixture(params=[pytest.param(gtir_cpu, marks=pytest.mark.requires_dace), gtx.gtfn_cpu]) +def gtir_dace_backend(request): + yield request.param + + +@pytest.fixture +def cartesian(request, gtir_dace_backend): + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider={ + "Ioff": IDim, + "Joff": JDim, + "Koff": KDim, + }, + default_sizes={IDim: 10, JDim: 10, KDim: 10}, + grid_type=common.GridType.CARTESIAN, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_input_names_extractor_cartesian(cartesian): + @gtx.field_operator(backend=cartesian.backend) + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program(backend=cartesian.backend) + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(b, out=c) + testee_op(a, out=b) + + input_field_names = extractors.InputNamesExtractor.only_fields(testee.itir) + assert input_field_names == {"a", "b"} + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_output_names_extractor(cartesian): + @gtx.field_operator(backend=cartesian.backend) + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program(backend=cartesian.backend) + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(a, out=b) + testee_op(a, out=c) + + output_field_names = extractors.OutputNamesExtractor.only_fields(testee.itir) + assert output_field_names == {"b", "c"}