Skip to content

Commit

Permalink
refactor extractors and remove debuginfo warning
Browse files Browse the repository at this point in the history
  • Loading branch information
DropD committed Dec 3, 2024
1 parent 77c250b commit f8ec8f5
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 130 deletions.
84 changes: 84 additions & 0 deletions src/gt4py/next/iterator/transforms/extractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py import eve
from gt4py.next.iterator import ir as itir
from gt4py.next.type_system import type_specifications as ts


class SymbolNameSetExtractor(eve.NodeVisitor):
"""Extract a set of symbol names"""

def generic_visitor(self, node: itir.Node) -> set[str]:
input_fields: set[str] = set()
for child in eve.trees.iter_children_values(node):
input_fields |= self.visit(child)
return input_fields

@classmethod
def only_fields(cls, program: itir.Program) -> set[str]:
field_param_names = [
str(param.id) for param in program.params if isinstance(param.type, ts.FieldType)
]
return {name for name in cls().visit(program) if name in field_param_names}


class InputNamesExtractor(SymbolNameSetExtractor):
"""Extract the set of symbol names passed into field operators within a program."""

def visit_Program(self, node: itir.Program) -> set[str]:
input_fields = set()
for stmt in node.body:
input_fields |= self.visit(stmt)
return input_fields

def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
input_fields = set()
for stmt in node.true_branch + node.false_branch:
input_fields |= self.visit(stmt)
return input_fields

def visit_Temporary(self, node: itir.Temporary) -> set[str]:
return set()

def visit_SetAt(self, node: itir.SetAt) -> set[str]:
return self.visit(node.expr)

def visit_FunCall(self, node: itir.FunCall) -> set[str]:
input_fields = set()
for arg in node.args:
input_fields |= self.visit(arg)
return input_fields

def visit_SymRef(self, node: itir.SymRef) -> set[str]:
return {str(node.id)}


class OutputNamesExtractor(SymbolNameSetExtractor):
"""Extract the set of symbol names written to within a program"""

def visit_Program(self, node: itir.Program) -> set[str]:
output_fields = set()
for stmt in node.body:
output_fields |= self.visit(stmt)
return output_fields

def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
output_fields = set()
for stmt in node.true_branch + node.false_branch:
output_fields |= self.visit(stmt)
return output_fields

def visit_Temporary(self, node: itir.Temporary) -> set[str]:
return set()

def visit_SetAt(self, node: itir.SetAt) -> set[str]:
return self.visit(node.target)

def visit_SymRef(self, node: itir.SymRef) -> set[str]:
return {str(node.id)}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import dace
import numpy as np

from gt4py import eve
from gt4py.next import backend as next_backend, common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import extractors as extractors
from gt4py.next.otf import arguments, recipes, toolchain
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.type_system import type_specifications as ts
Expand Down Expand Up @@ -96,10 +96,14 @@ def single_horizontal_dim_per_field(
if len(horizontal_dims) == 1:
yield str(field.id), horizontal_dims[0]

input_fields = (field_params[name] for name in InputNamesExtractor.only_fields(program))
input_fields = (
field_params[name] for name in extractors.InputNamesExtractor.only_fields(program)
)
sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields))

output_fields = (field_params[name] for name in OutputNamesExtractor.only_fields(program))
output_fields = (
field_params[name] for name in extractors.OutputNamesExtractor.only_fields(program)
)
sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields))

# TODO (ricoh): bring back sdfg.offset_providers_per_input_field.
Expand Down Expand Up @@ -191,79 +195,6 @@ def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]:
return [p.id for p in self.past_stage.past_node.params], []


class SymbolNameSetExtractor(eve.NodeVisitor):
"""Extract a set of symbol names"""

def generic_visitor(self, node: itir.Node) -> set[str]:
input_fields: set[str] = set()
for child in eve.trees.iter_children_values(node):
input_fields |= self.visit(child)
return input_fields

@classmethod
def only_fields(cls, program: itir.Program) -> set[str]:
field_param_names = [
str(param.id) for param in program.params if isinstance(param.type, ts.FieldType)
]
return {name for name in cls().visit(program) if name in field_param_names}


class InputNamesExtractor(SymbolNameSetExtractor):
"""Extract the set of symbol names passed into field operators within a program."""

def visit_Program(self, node: itir.Program) -> set[str]:
input_fields = set()
for stmt in node.body:
input_fields |= self.visit(stmt)
return input_fields

def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
input_fields = set()
for stmt in node.true_branch + node.false_branch:
input_fields |= self.visit(stmt)
return input_fields

def visit_Temporary(self, node: itir.Temporary) -> set[str]:
return set()

def visit_SetAt(self, node: itir.SetAt) -> set[str]:
return self.visit(node.expr)

def visit_FunCall(self, node: itir.FunCall) -> set[str]:
input_fields = set()
for arg in node.args:
input_fields |= self.visit(arg)
return input_fields

def visit_SymRef(self, node: itir.SymRef) -> set[str]:
return {str(node.id)}


class OutputNamesExtractor(SymbolNameSetExtractor):
"""Extract the set of symbol names written to within a program"""

def visit_Program(self, node: itir.Program) -> set[str]:
output_fields = set()
for stmt in node.body:
output_fields |= self.visit(stmt)
return output_fields

def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
output_fields = set()
for stmt in node.true_branch + node.false_branch:
output_fields |= self.visit(stmt)
return output_fields

def visit_Temporary(self, node: itir.Temporary) -> set[str]:
return set()

def visit_SetAt(self, node: itir.SetAt) -> set[str]:
return self.visit(node.target)

def visit_SymRef(self, node: itir.SymRef) -> set[str]:
return {str(node.id)}


def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None:
for dace_parsed_arg, gt4py_program_arg in zip(
dace_parsed_args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from collections.abc import Callable, Sequence
from inspect import currentframe, getframeinfo
from pathlib import Path
Expand Down Expand Up @@ -120,13 +119,7 @@ def build_sdfg_from_itir(

for nested_sdfg in sdfg.all_sdfgs_recursive():
if not nested_sdfg.debuginfo:
_, frameinfo = (
warnings.warn(
f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg.",
stacklevel=2,
),
getframeinfo(currentframe()), # type: ignore[arg-type]
)
frameinfo = getframeinfo(currentframe()) # type: ignore[arg-type]
nested_sdfg.debuginfo = dace.dtypes.DebugInfo(
start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from gt4py import next as gtx
from gt4py.next import common
from gt4py.next.program_processors.runners.dace_fieldview import program as dace_prg

from next_tests.integration_tests import cases
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
Expand All @@ -20,16 +19,18 @@
JDim,
KDim,
Vertex,
mesh_descriptor,
mesh_descriptor, # noqa: F401
)


try:
import dace

from gt4py.next.program_processors.runners.dace import gtir_cpu, gtir_gpu
except ImportError:
from typing import Optional
from types import ModuleType
from typing import Optional

from gt4py.next import backend as next_backend

dace: Optional[ModuleType] = None
Expand Down Expand Up @@ -66,7 +67,7 @@ def cartesian(request, gtir_dace_backend):


@pytest.fixture
def unstructured(request, gtir_dace_backend, mesh_descriptor):
def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811
if gtir_dace_backend is None:
yield None

Expand All @@ -84,48 +85,6 @@ def unstructured(request, gtir_dace_backend, mesh_descriptor):
)


@pytest.mark.skipif(dace is None, reason="DaCe not found")
def test_input_names_extractor_cartesian(cartesian):
@gtx.field_operator(backend=cartesian.backend)
def testee_op(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
) -> gtx.Field[[IDim, JDim, KDim], gtx.int]:
return a

@gtx.program(backend=cartesian.backend)
def testee(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
b: gtx.Field[[IDim, JDim, KDim], gtx.int],
c: gtx.Field[[IDim, JDim, KDim], gtx.int],
):
testee_op(b, out=c)
testee_op(a, out=b)

input_field_names = dace_prg.InputNamesExtractor.only_fields(testee.itir)
assert input_field_names == {"a", "b"}


@pytest.mark.skipif(dace is None, reason="DaCe not found")
def test_output_names_extractor(cartesian):
@gtx.field_operator(backend=cartesian.backend)
def testee_op(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
) -> gtx.Field[[IDim, JDim, KDim], gtx.int]:
return a

@gtx.program(backend=cartesian.backend)
def testee(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
b: gtx.Field[[IDim, JDim, KDim], gtx.int],
c: gtx.Field[[IDim, JDim, KDim], gtx.int],
):
testee_op(a, out=b)
testee_op(a, out=c)

output_field_names = dace_prg.OutputNamesExtractor.only_fields(testee.itir)
assert output_field_names == {"b", "c"}


@pytest.mark.skipif(dace is None, reason="DaCe not found")
def test_halo_exchange_helper_attrs(unstructured):
@gtx.field_operator(backend=unstructured.backend)
Expand Down
Loading

0 comments on commit f8ec8f5

Please sign in to comment.