Skip to content

Commit

Permalink
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
Browse files Browse the repository at this point in the history
…tir, and from itir to the SDFG): Preserve Location through Visitors
  • Loading branch information
kotsaloscv committed Jan 12, 2024
1 parent 9c9c8ae commit 6fb28a1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
8 changes: 7 additions & 1 deletion src/gt4py/eve/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,14 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:


class PreserveLocationVisitor(NodeVisitor):
preserve_location: bool = True

def __init__(self, preserve_location: bool = True) -> None:
super().__init__()
self.preserve_location = preserve_location

def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location"):
if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location:
result.location = node.location
return result
10 changes: 2 additions & 8 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
from typing import Any, Callable, Optional

from gt4py.eve import NodeTranslator, concepts, extended_typing
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.utils import UIDGenerator
from gt4py.next.ffront import (
dialect_ast_enums,
Expand All @@ -39,7 +39,7 @@ def promote_to_list(


@dataclasses.dataclass
class FieldOperatorLowering(NodeTranslator):
class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator):
"""
Lower FieldOperator AST (FOAST) to Iterator IR (ITIR).
Expand Down Expand Up @@ -72,12 +72,6 @@ class FieldOperatorLowering(NodeTranslator):
def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.Expr:
return cls(preserve_location=preserve_location).visit(node)

def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any:
result = super().visit(node, **kwargs)
if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location:
result.location = node.location
return result

def visit_FunctionDefinition(
self, node: foast.FunctionDefinition, **kwargs
) -> itir.FunctionDefinition:
Expand Down

0 comments on commit 6fb28a1

Please sign in to comment.