Skip to content

Commit

Permalink
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
Browse files Browse the repository at this point in the history
…tir, and from itir to the SDFG): Preserve Location through Visitors [WIP]
  • Loading branch information
kotsaloscv committed Dec 13, 2023
1 parent 5632def commit b59fd83
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
12 changes: 10 additions & 2 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
import operator
import typing

from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait
from gt4py.eve import (
NodeTranslator,
NodeVisitor,
SymbolTableTrait,
VisitorWithSymbolTableTrait,
traits,
)
from gt4py.eve.utils import UIDGenerator
from gt4py.eve.visitors import PreserveLocation
from gt4py.next.iterator import ir
Expand Down Expand Up @@ -73,7 +79,9 @@ def _is_collectable_expr(node: ir.Node) -> bool:


@dataclasses.dataclass
class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor):
class CollectSubexpressions(
traits.PreserveLocationWithSymbolTableTrait, VisitorWithSymbolTableTrait, NodeVisitor
):
@dataclasses.dataclass
class SubexpressionData:
#: A list of node ids with equal hash and a set of collected child subexpression ids
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/iterator/transforms/eta_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node:
for p, a in zip(node.params, node.expr.args)
)
):
node.expr.fun.location = node.location
return self.visit(node.expr.fun)

return self.generic_visit(node)
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator):
class FuseMaps(
traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator
):
"""
Fuses nested `map_`s.
Expand Down Expand Up @@ -125,12 +127,10 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]),
args=new_args,
location=node.location,
)
else: # _is_reduce(node)
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]),
args=new_args,
location=node.location,
)
return node
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
stencil=stencil,
output=im.ref(tmp_sym.id),
inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined]
location=current_closure.location,
)
)

Expand Down Expand Up @@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
output=current_closure.output,
inputs=current_closure.inputs
+ [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()],
location=current_closure.location,
)
)
else:
Expand All @@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
+ [ir.Sym(id=tmp.id) for tmp in tmps]
+ [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant
closures=list(reversed(closures)),
location=node.location,
),
params=node.params,
tmps=[Temporary(id=tmp.id) for tmp in tmps],
Expand All @@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari
function_definitions=node.fencil.function_definitions,
params=[p for p in node.fencil.params if p.id not in unused_tmps],
closures=closures,
location=node.fencil.location,
),
params=node.params,
tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps],
Expand Down Expand Up @@ -453,6 +457,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
stencil=closure.stencil,
output=closure.output,
inputs=closure.inputs,
location=closure.location,
)
else:
domain = closure.domain
Expand Down Expand Up @@ -505,6 +510,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
function_definitions=node.fencil.function_definitions,
params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again
closures=list(reversed(closures)),
location=node.fencil.location,
),
params=node.params,
tmps=node.tmps,
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/iterator/transforms/inline_lifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def visit_FunCall(
ir.FunCall(
fun=self.generic_visit(node.fun, is_scan_pass_context=_is_scan(node), **kwargs),
args=self.generic_visit(node.args, **kwargs),
location=node.location,
)
if recurse
else node
Expand Down

0 comments on commit b59fd83

Please sign in to comment.