Skip to content

Commit

Permalink
feat[next]: Change interval syntax in ITIR pretty printer (#1766)
Browse files Browse the repository at this point in the history
We currently use `)` in the pretty printer to express an open interval.
This is quite cumbersome when debugging the IR because it breaks
matching parenthesis in the editor of functions and calls, e.g. when
does a function start and end. This PR simply uses `[` instead.
  • Loading branch information
tehrengruber authored Dec 10, 2024
1 parent 29b6af2 commit 9888905
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 12 deletions.
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,11 @@ def domain(
... },
... )
... )
'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩'
'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩'
>>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)}))
'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩'
'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩'
>>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)}))
'u⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩'
'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩'
"""
if isinstance(grid_type, common.GridType):
grid_type = f"{grid_type!s}_domain"
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
else_branch_seperator: "else"
if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}"
named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")"
named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 "["
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def visit_FunCall(self, node: ir.FunCall, *, prec: int) -> list[str]:
if fun_name == "named_range" and len(node.args) == 3:
# named_range(dim, start, stop) → dim: [star, stop)
dim, start, end = self.visit(node.args, prec=0)
res = self._hmerge(dim, [": ["], start, [", "], end, [")"])
res = self._hmerge(
dim, [": ["], start, [", "], end, ["["]
) # to get matching parenthesis of functions
return self._prec_parens(res, prec, PRECEDENCE["__call__"])
if fun_name == "cartesian_domain" and len(node.args) >= 1:
# cartesian_domain(x, y, ...) → c{ x × y × ... } # noqa: RUF003 [ambiguous-unicode-character-comment]
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ class FuseAsFieldOp(eve.NodeTranslator):
... im.ref("inp3", field_type),
... )
>>> print(nested_as_fieldop)
as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(
as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3
as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)(
as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2), inp3
)
>>> print(
... FuseAsFieldOp.apply(
... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True
... )
... )
as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3)
as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3)
""" # noqa: RUF002 # ignore ambiguous multiplication character

uids: eve_utils.UIDGenerator
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/inline_fundefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program:
>>> print(prune_unreferenced_fundefs(program))
testee(inp, out) {
fun1 = λ(a) → ·a;
out @ c⟨ IDimₕ: [0, 10) ⟩ ← fun1(inp);
out @ c⟨ IDimₕ: [0, 10[ ⟩ ← fun1(inp);
}
"""
fun_names = [fun.id for fun in program.function_definitions]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_make_tuple():


def test_named_range_horizontal():
testee = "IDimₕ: [x, y)"
testee = "IDimₕ: [x, y["
expected = ir.FunCall(
fun=ir.SymRef(id="named_range"),
args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")],
Expand All @@ -137,7 +137,7 @@ def test_named_range_horizontal():


def test_named_range_vertical():
testee = "IDimᵥ: [x, y)"
testee = "IDimᵥ: [x, y["
expected = ir.FunCall(
fun=ir.SymRef(id="named_range"),
args=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_named_range_horizontal():
fun=ir.SymRef(id="named_range"),
args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")],
)
expected = "IDimₕ: [x, y)"
expected = "IDimₕ: [x, y["
actual = pformat(testee)
assert actual == expected

Expand Down

0 comments on commit 9888905

Please sign in to comment.