Skip to content

Commit

Permalink
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
Browse files Browse the repository at this point in the history
…tir, and from itir to the SDFG): Preserve Location through Visitors
  • Loading branch information
kotsaloscv committed Jan 12, 2024
1 parent 50f96a8 commit 9c9c8ae
Show file tree
Hide file tree
Showing 25 changed files with 64 additions and 93 deletions.
3 changes: 2 additions & 1 deletion src/gt4py/eve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
walk_values,
)
from .type_definitions import NOTHING, ConstrainedStr, Enum, IntEnum, NothingType, StrEnum
from .visitors import NodeTranslator, NodeVisitor
from .visitors import NodeTranslator, NodeVisitor, PreserveLocationVisitor


__all__ = [
Expand Down Expand Up @@ -132,4 +132,5 @@
# visitors
"NodeTranslator",
"NodeVisitor",
"PreserveLocationVisitor",
]
8 changes: 0 additions & 8 deletions src/gt4py/eve/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,3 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
kwargs["symtable"] = kwargs["symtable"].parents

return result


class PreserveLocationWithSymbolTableTrait(VisitorWithSymbolTableTrait):
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location"):
result.location = node.location
return result
2 changes: 1 addition & 1 deletion src/gt4py/eve/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
return copy.deepcopy(node, memo=memo)


class PreserveLocation(NodeVisitor):
class PreserveLocationVisitor(NodeVisitor):
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location"):
Expand Down
7 changes: 1 addition & 6 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,7 @@ def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.

def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any:
result = super().visit(node, **kwargs)
if (
hasattr(node, "location")
and hasattr(result, "location")
and not isinstance(node, foast.Name)
and self.preserve_location
):
if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location:
result.location = node.location
return result

Expand Down
15 changes: 6 additions & 9 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Optional, cast

from gt4py.eve import NodeTranslator, concepts, traits
from gt4py.eve import NodeTranslator, PreserveLocationVisitor, concepts, traits
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront
from gt4py.next.iterator import ir as itir
Expand All @@ -40,9 +40,7 @@ def _flatten_tuple_expr(
raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.")


class ProgramLowering(
traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
class ProgramLowering(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Lower Program AST (PAST) to Iterator IR (ITIR).
Expand Down Expand Up @@ -256,7 +254,9 @@ def _construct_itir_domain_arg(
raise AssertionError()

return itir.FunCall(
fun=itir.SymRef(id=domain_builtin), args=domain_args, location=out_field.location
fun=itir.SymRef(id=domain_builtin),
args=domain_args,
location=(node_domain or out_field).location,
)

def _construct_itir_initialized_domain_arg(
Expand All @@ -273,10 +273,7 @@ def _construct_itir_initialized_domain_arg(
f"expected '{dim}', got '{keys_dims_types}'."
)

itir_node = [self.visit(bound) for bound in node_domain.values_[dim_i].elts]
for i, bound in enumerate(node_domain.values_[dim_i].elts):
itir_node[i].location = bound.location
return itir_node
return [self.visit(bound) for bound in node_domain.values_[dim_i].elts]

@staticmethod
def _compute_field_slice(node: past.Subscript):
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py import eve
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next.iterator import ir


class CollapseListGet(PreserveLocation, eve.NodeTranslator):
class CollapseListGet(PreserveLocationVisitor, eve.NodeTranslator):
"""Simplifies expressions containing `list_get`.
Examples
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Optional

from gt4py import eve
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference

Expand Down Expand Up @@ -49,7 +49,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t


@dataclass(frozen=True)
class CollapseTuple(PreserveLocation, eve.NodeTranslator):
class CollapseTuple(PreserveLocationVisitor, eve.NodeTranslator):
"""
Simplifies `make_tuple`, `tuple_get` calls.
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next.iterator import embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im


class ConstantFolding(PreserveLocation, NodeTranslator):
class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
Expand Down
18 changes: 5 additions & 13 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,15 @@
import operator
import typing

from gt4py.eve import (
NodeTranslator,
NodeVisitor,
SymbolTableTrait,
VisitorWithSymbolTableTrait,
traits,
)
from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda


@dataclasses.dataclass
class _NodeReplacer(PreserveLocation, NodeTranslator):
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)

expr_map: dict[int, ir.SymRef]
Expand Down Expand Up @@ -79,9 +73,7 @@ def _is_collectable_expr(node: ir.Node) -> bool:


@dataclasses.dataclass
class CollectSubexpressions(
traits.PreserveLocationWithSymbolTableTrait, VisitorWithSymbolTableTrait, NodeVisitor
):
class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor):
@dataclasses.dataclass
class SubexpressionData:
#: A list of node ids with equal hash and a set of collected child subexpression ids
Expand Down Expand Up @@ -350,7 +342,7 @@ def extract_subexpression(


@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(PreserveLocation, NodeTranslator):
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
"""
Perform common subexpression elimination.
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/eta_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next.iterator import ir


class EtaReduction(PreserveLocation, NodeTranslator):
class EtaReduction(PreserveLocationVisitor, NodeTranslator):
"""Eta reduction: simplifies `λ(args...) → f(args...)` to `f`."""

def visit_Lambda(self, node: ir.Lambda) -> ir.Node:
Expand Down
6 changes: 2 additions & 4 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
from typing import TypeGuard

from gt4py.eve import NodeTranslator, traits
from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms import inline_lambdas
Expand All @@ -38,9 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:


@dataclasses.dataclass(frozen=True)
class FuseMaps(
traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
class FuseMaps(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Fuses nested `map_`s.
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from gt4py.eve import Coerced, NodeTranslator
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next.iterator import ir, type_inference
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
Expand Down Expand Up @@ -570,7 +570,7 @@ def convert_type(dtype):
# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be
# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore
# and hence also not extract as a temporary.
class CreateGlobalTmps(PreserveLocation, NodeTranslator):
class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator):
"""Main entry point for introducing global temporaries.
Transforms an existing iterator IR fencil into a fencil with global temporaries.
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/inline_fundefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from typing import Any, Dict, Set

from gt4py.eve import NOTHING, NodeTranslator
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next.iterator import ir


class InlineFundefs(PreserveLocation, NodeTranslator):
class InlineFundefs(PreserveLocationVisitor, NodeTranslator):
def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]):
if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition):
return ir.Lambda(
Expand All @@ -32,7 +32,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition):
return self.generic_visit(node, symtable=node.annex.symtable)


class PruneUnreferencedFundefs(PreserveLocation, NodeTranslator):
class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator):
def visit_FunctionDefinition(
self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool
):
Expand Down
6 changes: 2 additions & 4 deletions src/gt4py/next/iterator/transforms/inline_into_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Sequence, TypeGuard

from gt4py import eve
from gt4py.eve import NodeTranslator, traits
from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms import symbol_ref_utils
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
Expand Down Expand Up @@ -53,9 +53,7 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall:
return inlined


class InlineIntoScan(
traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
class InlineIntoScan(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Inline non-SymRef arguments into the scan.
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Optional

from gt4py.eve import NodeTranslator
from gt4py.eve.visitors import PreserveLocation
from gt4py.eve.visitors import PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols
Expand Down Expand Up @@ -123,7 +123,7 @@ def new_name(name):


@dataclasses.dataclass
class InlineLambdas(PreserveLocation, NodeTranslator):
class InlineLambdas(PreserveLocationVisitor, NodeTranslator):
"""Inline lambda calls by substituting every argument by its value."""

PRESERVED_ANNEX_ATTRS = ("type",)
Expand Down
24 changes: 11 additions & 13 deletions src/gt4py/next/iterator/transforms/inline_lifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Optional

import gt4py.eve as eve
from gt4py.eve import NodeTranslator, traits
from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
Expand All @@ -40,10 +40,10 @@ def _generate_unique_symbol(
else:
desired_name = f"__arg{arg_idx}"

new_symbol = ir.Sym(id=desired_name)
new_symbol = desired_name
# make unique
while new_symbol.id in occupied_names or new_symbol in occupied_symbols:
new_symbol = ir.Sym(id=new_symbol.id + "_")
new_symbol = new_symbol + "_"
return new_symbol


Expand Down Expand Up @@ -73,7 +73,7 @@ def _is_scan(node: ir.FunCall):
def _transform_and_extract_lift_args(
node: ir.FunCall,
symtable: dict[eve.SymbolName, ir.Sym],
extracted_args: dict[ir.Sym, ir.Expr],
extracted_args: dict[eve.SymbolName, ir.Expr],
):
"""
Transform and extract non-symbol arguments of a lifted stencil call.
Expand All @@ -89,8 +89,8 @@ def _transform_and_extract_lift_args(
new_args = []
for i, arg in enumerate(node.args):
if isinstance(arg, ir.SymRef):
sym = ir.Sym(id=arg.id)
assert sym not in extracted_args or extracted_args[sym] == arg
sym = arg.id
assert sym not in extracted_args or extracted_args[sym].id == arg.id
extracted_args[sym] = arg
new_args.append(arg)
else:
Expand All @@ -101,7 +101,7 @@ def _transform_and_extract_lift_args(
)
assert new_symbol not in extracted_args
extracted_args[new_symbol] = arg
new_args.append(ir.SymRef(id=new_symbol.id))
new_args.append(ir.SymRef(id=new_symbol))

itir_node = im.lift(inner_stencil)(*new_args)
itir_node.location = node.location
Expand All @@ -112,9 +112,7 @@ def _transform_and_extract_lift_args(
# passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without
# performance degradation we leave everything as one pass for now.
@dataclasses.dataclass
class InlineLifts(
traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
class InlineLifts(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""Inline lifted function calls.
Optionally a predicate function can be passed which can enable or disable inlining of specific
Expand Down Expand Up @@ -228,7 +226,7 @@ def visit_FunCall(
# TODO(tehrengruber): we currently only inlining opcount preserving, but what we
# actually want is to inline whenever the argument is not shifted. This is
# currently beyond the capabilities of the inliner and the shift tracer.
new_arg_exprs: dict[ir.Sym, ir.Expr] = {}
new_arg_exprs: dict[eve.SymbolName, ir.Expr] = {}
inlined_args = []
for i, (arg, eligible) in enumerate(zip(node.args, eligible_lifted_args)):
if eligible:
Expand All @@ -239,7 +237,7 @@ def visit_FunCall(
inlined_args.append(inlined_arg)
else:
if isinstance(arg, ir.SymRef):
new_arg_sym = ir.Sym(id=arg.id)
new_arg_sym = arg.id
else:
new_arg_sym = _generate_unique_symbol(
desired_name=(stencil, i),
Expand All @@ -248,7 +246,7 @@ def visit_FunCall(
)

new_arg_exprs[new_arg_sym] = arg
inlined_args.append(ir.SymRef(id=new_arg_sym.id))
inlined_args.append(ir.SymRef(id=new_arg_sym))

inlined_call = self.visit(
inline_lambda(
Expand Down
Loading

0 comments on commit 9c9c8ae

Please sign in to comment.