Skip to content

Commit

Permalink
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
Browse files Browse the repository at this point in the history
…tir, and from itir to the SDFG): Preserve Location through Visitors [WIP]
  • Loading branch information
kotsaloscv committed Dec 14, 2023
1 parent bb880dd commit e0a254f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 72 deletions.
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,19 @@ class FieldOperatorLowering(NodeTranslator):
"""

uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator)
preserve_location: bool = True

@classmethod
def apply(cls, node: foast.LocatedNode) -> itir.Expr:
return cls().visit(node)
def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.Expr:
return cls(preserve_location=preserve_location).visit(node)

def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any:
result = super().visit(node, **kwargs)
if (
hasattr(node, "location")
and hasattr(result, "location")
and not isinstance(node, foast.Name)
and self.preserve_location
):
result.location = node.location
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import hashlib
import warnings
from inspect import currentframe, getframeinfo
from typing import Any, Mapping, Optional, Sequence

import dace
Expand Down Expand Up @@ -221,15 +222,15 @@ def build_sdfg_from_itir(
sdfg = sdfg_genenerator.visit(program)
for nested_sdfg in sdfg.all_sdfgs_recursive():
if not nested_sdfg.debuginfo:
warnings.warn(
_, frameinfo = warnings.warn(
f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg."
), getframeinfo(
currentframe() # type: ignore
)
nested_sdfg.debuginfo = dace.dtypes.DebugInfo(
start_line=0,
start_column=0,
end_line=-1,
end_column=0,
filename=None,
start_line=frameinfo.lineno,
end_line=frameinfo.lineno,
filename=frameinfo.filename,
)
sdfg.simplify()

Expand Down
96 changes: 32 additions & 64 deletions tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]):
return inp

parsed = FieldOperatorParser.apply_to_function(copy_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

assert lowered.id == "copy_field"
assert lowered.expr == im.ref("inp")
Expand All @@ -70,8 +69,7 @@ def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim],
return alpha * bar

parsed = FieldOperatorParser.apply_to_function(scalar_arg)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("multiplies")(
"alpha", "bar"
Expand All @@ -85,8 +83,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]
return inp1, inp2

parsed = FieldOperatorParser.apply_to_function(multicopy)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2")

Expand All @@ -98,8 +95,7 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64
return inp1 + inp2

parsed = FieldOperatorParser.apply_to_function(arithmetic)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2")

Expand All @@ -113,8 +109,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]):
return inp(Ioff[1])

parsed = FieldOperatorParser.apply_to_function(shift_by_one)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp")

Expand All @@ -128,8 +123,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]):
return inp(Ioff[-1])

parsed = FieldOperatorParser.apply_to_function(shift_by_one)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp")

Expand All @@ -144,8 +138,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]):
return tmp2

parsed = FieldOperatorParser.apply_to_function(copy_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.let(
itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), "inp"
Expand All @@ -171,8 +164,7 @@ def unary(inp: gtx.Field[[TDim], float64]):
return tmp

parsed = FieldOperatorParser.apply_to_function(unary)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.let(
itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"),
Expand Down Expand Up @@ -201,8 +193,7 @@ def unpacking(
return tmp1

parsed = FieldOperatorParser.apply_to_function(unpacking)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

tuple_expr = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2")
tuple_access_0 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(0, x))("__tuple_tmp_0")
Expand Down Expand Up @@ -231,8 +222,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]):
return tmp

parsed = FieldOperatorParser.apply_to_function(copy_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0))

Expand All @@ -256,8 +246,7 @@ def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]:
return identity(inp)

parsed = FieldOperatorParser.apply_to_function(call)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.lift(im.lambda_("__arg0")(im.call("identity")("__arg0")))("inp")

Expand All @@ -272,8 +261,7 @@ def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]):
return tmp

parsed = FieldOperatorParser.apply_to_function(temp_tuple)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

tuple_expr = im.promote_to_lifted_stencil("make_tuple")("a", "b")
reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0))
Expand All @@ -286,8 +274,7 @@ def unary_not(cond: gtx.Field[[TDim], "bool"]):
return not cond

parsed = FieldOperatorParser.apply_to_function(unary_not)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("not_")("cond")

Expand All @@ -299,8 +286,7 @@ def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):
return a + b

parsed = FieldOperatorParser.apply_to_function(plus)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("plus")("a", "b")

Expand All @@ -312,8 +298,7 @@ def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float6
return 2.0 + a

parsed = FieldOperatorParser.apply_to_function(scalar_plus_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("plus")(
im.promote_to_const_iterator(im.literal("2.0", "float64")), "a"
Expand All @@ -328,8 +313,7 @@ def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int3
return a + tmp

parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.let(
ssa.unique_name("tmp", 0),
Expand All @@ -347,8 +331,7 @@ def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):
return a * b

parsed = FieldOperatorParser.apply_to_function(mult)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("multiplies")("a", "b")

Expand All @@ -360,8 +343,7 @@ def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):
return a - b

parsed = FieldOperatorParser.apply_to_function(minus)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("minus")("a", "b")

Expand All @@ -373,8 +355,7 @@ def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):
return a / b

parsed = FieldOperatorParser.apply_to_function(division)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("divides")("a", "b")

Expand All @@ -386,8 +367,7 @@ def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]):
return a & b

parsed = FieldOperatorParser.apply_to_function(bit_and)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("and_")("a", "b")

Expand All @@ -399,8 +379,7 @@ def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]:
return a & False

parsed = FieldOperatorParser.apply_to_function(scalar_and)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("and_")(
"a", im.promote_to_const_iterator(im.literal("False", "bool"))
Expand All @@ -414,8 +393,7 @@ def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]):
return a | b

parsed = FieldOperatorParser.apply_to_function(bit_or)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("or_")("a", "b")

Expand All @@ -427,8 +405,7 @@ def comp_scalars() -> bool:
return 3 > 4

parsed = FieldOperatorParser.apply_to_function(comp_scalars)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("greater")(
im.promote_to_const_iterator(im.literal("3", "int32")),
Expand All @@ -443,8 +420,7 @@ def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):
return a > b

parsed = FieldOperatorParser.apply_to_function(comp_gt)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("greater")("a", "b")

Expand All @@ -456,8 +432,7 @@ def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):
return a < b

parsed = FieldOperatorParser.apply_to_function(comp_lt)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("less")("a", "b")

Expand All @@ -469,8 +444,7 @@ def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]):
return a == b

parsed = FieldOperatorParser.apply_to_function(comp_eq)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("eq")("a", "b")

Expand All @@ -484,8 +458,7 @@ def compare_chain(
return a > b > c

parsed = FieldOperatorParser.apply_to_function(compare_chain)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("and_")(
im.promote_to_lifted_stencil("greater")("a", "b"),
Expand All @@ -500,8 +473,7 @@ def reduction(edge_f: gtx.Field[[Edge], float64]):
return neighbor_sum(edge_f(V2E), axis=V2EDim)

parsed = FieldOperatorParser.apply_to_function(reduction)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil(
im.call(
Expand All @@ -523,8 +495,7 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl
return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim)

parsed = FieldOperatorParser.apply_to_function(reduction)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))(
im.promote_to_lifted_stencil("make_const_list")(
Expand Down Expand Up @@ -567,8 +538,7 @@ def int_constrs() -> (
return 1, int32(1), int64(1), int32("1"), int64("1")

parsed = FieldOperatorParser.apply_to_function(int_constrs)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("make_tuple")(
im.promote_to_const_iterator(im.literal("1", "int32")),
Expand Down Expand Up @@ -604,8 +574,7 @@ def float_constrs() -> (
)

parsed = FieldOperatorParser.apply_to_function(float_constrs)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("make_tuple")(
im.promote_to_const_iterator(im.literal("0.1", "float64")),
Expand All @@ -625,8 +594,7 @@ def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]:
return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False")

parsed = FieldOperatorParser.apply_to_function(bool_constrs)
lowered = FieldOperatorLowering.apply(parsed)
lowered.expr.location = None
lowered = FieldOperatorLowering.apply(parsed, preserve_location=False)

reference = im.promote_to_lifted_stencil("make_tuple")(
im.promote_to_const_iterator(im.literal(str(True), "bool")),
Expand Down

0 comments on commit e0a254f

Please sign in to comment.