Skip to content

Commit

Permalink
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
Browse files Browse the repository at this point in the history
…tir, and from itir to the SDFG): Preserve Location through Visitors [WIP]
  • Loading branch information
kotsaloscv committed Dec 14, 2023
1 parent 1ed9764 commit 371dc36
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(copy_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

assert lowered.id == "copy_field"
assert lowered.expr == im.ref("inp")
Expand All @@ -70,6 +71,7 @@ def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim],

parsed = FieldOperatorParser.apply_to_function(scalar_arg)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("multiplies")(
"alpha", "bar"
Expand All @@ -84,6 +86,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]

parsed = FieldOperatorParser.apply_to_function(multicopy)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2")

Expand All @@ -96,6 +99,7 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64

parsed = FieldOperatorParser.apply_to_function(arithmetic)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2")

Expand All @@ -110,6 +114,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]):

parsed = FieldOperatorParser.apply_to_function(shift_by_one)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp")

Expand All @@ -124,6 +129,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]):

parsed = FieldOperatorParser.apply_to_function(shift_by_one)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp")

Expand All @@ -139,6 +145,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(copy_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.let(
itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), "inp"
Expand All @@ -165,6 +172,7 @@ def unary(inp: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(unary)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.let(
itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"),
Expand Down Expand Up @@ -194,6 +202,7 @@ def unpacking(

parsed = FieldOperatorParser.apply_to_function(unpacking)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

tuple_expr = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2")
tuple_access_0 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(0, x))("__tuple_tmp_0")
Expand Down Expand Up @@ -223,6 +232,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(copy_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0))

Expand All @@ -247,6 +257,7 @@ def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]:

parsed = FieldOperatorParser.apply_to_function(call)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.lift(im.lambda_("__arg0")(im.call("identity")("__arg0")))("inp")

Expand All @@ -262,6 +273,7 @@ def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]):

parsed = FieldOperatorParser.apply_to_function(temp_tuple)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

tuple_expr = im.promote_to_lifted_stencil("make_tuple")("a", "b")
reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0))
Expand All @@ -275,6 +287,7 @@ def unary_not(cond: gtx.Field[[TDim], "bool"]):

parsed = FieldOperatorParser.apply_to_function(unary_not)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("not_")("cond")

Expand All @@ -287,6 +300,7 @@ def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(plus)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("plus")("a", "b")

Expand All @@ -299,6 +313,7 @@ def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float6

parsed = FieldOperatorParser.apply_to_function(scalar_plus_field)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("plus")(
im.promote_to_const_iterator(im.literal("2.0", "float64")), "a"
Expand All @@ -314,6 +329,7 @@ def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int3

parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.let(
ssa.unique_name("tmp", 0),
Expand All @@ -332,6 +348,7 @@ def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(mult)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("multiplies")("a", "b")

Expand All @@ -344,6 +361,7 @@ def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(minus)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("minus")("a", "b")

Expand All @@ -356,6 +374,7 @@ def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(division)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("divides")("a", "b")

Expand All @@ -368,6 +387,7 @@ def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]):

parsed = FieldOperatorParser.apply_to_function(bit_and)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("and_")("a", "b")

Expand All @@ -380,6 +400,7 @@ def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]:

parsed = FieldOperatorParser.apply_to_function(scalar_and)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("and_")(
"a", im.promote_to_const_iterator(im.literal("False", "bool"))
Expand All @@ -394,6 +415,7 @@ def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]):

parsed = FieldOperatorParser.apply_to_function(bit_or)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("or_")("a", "b")

Expand All @@ -406,6 +428,7 @@ def comp_scalars() -> bool:

parsed = FieldOperatorParser.apply_to_function(comp_scalars)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("greater")(
im.promote_to_const_iterator(im.literal("3", "int32")),
Expand All @@ -421,6 +444,7 @@ def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(comp_gt)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("greater")("a", "b")

Expand All @@ -433,6 +457,7 @@ def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(comp_lt)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("less")("a", "b")

Expand All @@ -445,6 +470,7 @@ def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]):

parsed = FieldOperatorParser.apply_to_function(comp_eq)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("eq")("a", "b")

Expand All @@ -459,6 +485,7 @@ def compare_chain(

parsed = FieldOperatorParser.apply_to_function(compare_chain)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("and_")(
im.promote_to_lifted_stencil("greater")("a", "b"),
Expand All @@ -474,6 +501,7 @@ def reduction(edge_f: gtx.Field[[Edge], float64]):

parsed = FieldOperatorParser.apply_to_function(reduction)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil(
im.call(
Expand All @@ -496,6 +524,7 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl

parsed = FieldOperatorParser.apply_to_function(reduction)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))(
im.promote_to_lifted_stencil("make_const_list")(
Expand Down Expand Up @@ -539,6 +568,7 @@ def int_constrs() -> (

parsed = FieldOperatorParser.apply_to_function(int_constrs)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("make_tuple")(
im.promote_to_const_iterator(im.literal("1", "int32")),
Expand Down Expand Up @@ -575,6 +605,7 @@ def float_constrs() -> (

parsed = FieldOperatorParser.apply_to_function(float_constrs)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("make_tuple")(
im.promote_to_const_iterator(im.literal("0.1", "float64")),
Expand All @@ -595,6 +626,7 @@ def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]:

parsed = FieldOperatorParser.apply_to_function(bool_constrs)
lowered = FieldOperatorLowering.apply(parsed)
lowered.location = None

reference = im.promote_to_lifted_stencil("make_tuple")(
im.promote_to_const_iterator(im.literal(str(True), "bool")),
Expand Down

0 comments on commit 371dc36

Please sign in to comment.