From 6fb28a104ad36138c4320b9957aef819266478b1 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Fri, 12 Jan 2024 14:33:41 +0100 Subject: [PATCH] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/eve/visitors.py | 8 +++++++- src/gt4py/next/ffront/foast_to_itir.py | 10 ++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index c3b9f3abf3..c0a0054f5a 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -199,8 +199,14 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: class PreserveLocationVisitor(NodeVisitor): + preserve_location: bool = True + + def __init__(self, preserve_location: bool = True) -> None: + super().__init__() + self.preserve_location = preserve_location + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location"): + if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location: result.location = node.location return result diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index ba771ed5b5..4a88553532 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -15,7 +15,7 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator, concepts, extended_typing +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next.ffront import ( dialect_ast_enums, @@ -39,7 +39,7 @@ def promote_to_list( @dataclasses.dataclass -class FieldOperatorLowering(NodeTranslator): +class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). @@ -72,12 +72,6 @@ class FieldOperatorLowering(NodeTranslator): def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.Expr: return cls(preserve_location=preserve_location).visit(node) - def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any: - result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location: - result.location = node.location - return result - def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs ) -> itir.FunctionDefinition: