Skip to content

Commit

Permalink
refactor[next]: Remove GTIR cond and use if_ everywhere (#1665)
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber authored Sep 27, 2024
1 parent fb1d494 commit ac513a1
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 65 deletions.
15 changes: 7 additions & 8 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,9 @@ class FieldOperatorLowering(eve.PreserveLocationVisitor, eve.NodeTranslator):
Lower FieldOperator AST (FOAST) to GTIR.
Most expressions are lowered to `as_fieldop`ed stencils.
Pure scalar expressions are kept as scalar operations as they might appear outside of the stencil context,
e.g. in `cond`.
In arithemtic operations that involve a field and a scalar, the scalar is implicitly broadcasted to a field
in the `as_fieldop` call.
Pure scalar expressions are kept as scalar operations as they might appear outside of the
stencil context. In arithmetic operations that involve a field and a scalar, the scalar is
implicitly broadcasted to a field in the `as_fieldop` call.
Examples
--------
Expand Down Expand Up @@ -164,7 +163,7 @@ def visit_IfStmt(
inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr)

# here we assume neither branch returns
return im.let("__if_stmt_result", im.cond(cond, true_branch, false_branch))(inner_expr)
return im.let("__if_stmt_result", im.if_(cond, true_branch, false_branch))(inner_expr)
elif return_kind is foast_introspection.StmtReturnKind.CONDITIONAL_RETURN:
common_syms = tuple(im.sym(sym) for sym in common_symbols.keys())
common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys())
Expand All @@ -179,7 +178,7 @@ def visit_IfStmt(
false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs)

return im.let(inner_expr_name, inner_expr_evaluator)(
im.cond(cond, true_branch, false_branch)
im.if_(cond, true_branch, false_branch)
)

assert return_kind is foast_introspection.StmtReturnKind.UNCONDITIONAL_RETURN
Expand All @@ -189,7 +188,7 @@ def visit_IfStmt(
true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs)
false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs)

return im.cond(cond, true_branch, false_branch)
return im.if_(cond, true_branch, false_branch)

def visit_Assign(
self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any
Expand Down Expand Up @@ -232,7 +231,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
isinstance(node.condition.type, ts.ScalarType)
and node.condition.type.kind == ts.ScalarKind.BOOL
)
return im.cond(
return im.if_(
self.visit(node.condition), self.visit(node.true_expr), self.visit(node.false_expr)
)

Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
GTIR_BUILTINS = {
*BUILTINS,
"as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution)
"cond", # `cond(expr, field_a, field_b)` creates the field on one branch or the other
}


Expand Down
5 changes: 0 additions & 5 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,6 @@ def if_(cond, true_val, false_val):
return call("if_")(cond, true_val, false_val)


def cond(cond, true_val, false_val):
"""Create a cond FunCall, shorthand for ``call("cond")(expr)``."""
return call("cond")(cond, true_val, false_val)


def lift(expr):
"""Create a lift FunCall, shorthand for ``call(call("lift")(expr))``."""
return call(call("lift")(expr))
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ def infer_tuple_get(
return infered_args_expr, actual_domains


def infer_cond(
def infer_if(
expr: itir.Expr,
domain: DOMAIN,
offset_provider: common.OffsetProvider,
) -> tuple[itir.Expr, ACCESSED_DOMAINS]:
assert cpm.is_call_to(expr, "cond")
assert cpm.is_call_to(expr, "if_")
infered_args_expr = []
actual_domains: ACCESSED_DOMAINS = {}
cond, true_val, false_val = expr.args
Expand Down Expand Up @@ -293,8 +293,8 @@ def infer_expr(
return infer_make_tuple(expr, domain, offset_provider)
elif cpm.is_call_to(expr, "tuple_get"):
return infer_tuple_get(expr, domain, offset_provider)
elif cpm.is_call_to(expr, "cond"):
return infer_cond(expr, domain, offset_provider)
elif cpm.is_call_to(expr, "if_"):
return infer_if(expr, domain, offset_provider)
elif (
cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS)
or cpm.is_call_to(expr, itir.TYPEBUILTINS)
Expand Down
43 changes: 17 additions & 26 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@ def _register_builtin_type_synthesizer(
*,
fun_names: Optional[Iterable[str]] = None,
):
def wrapper(synthesizer: Callable[..., TypeOrTypeSynthesizer]) -> None:
# store names in function object for better debuggability
synthesizer.fun_names = fun_names or [synthesizer.__name__] # type: ignore[attr-defined]
for f in synthesizer.fun_names: # type: ignore[attr-defined]
builtin_type_synthesizers[f] = TypeSynthesizer(type_synthesizer=synthesizer)
if synthesizer is None:
return functools.partial(_register_builtin_type_synthesizer, fun_names=fun_names)

return wrapper(synthesizer) if synthesizer else wrapper
# store names in function object for better debuggability
synthesizer.fun_names = fun_names or [synthesizer.__name__] # type: ignore[attr-defined]
for f in synthesizer.fun_names: # type: ignore[attr-defined]
builtin_type_synthesizers[f] = TypeSynthesizer(type_synthesizer=synthesizer)
return synthesizer


@_register_builtin_type_synthesizer(
Expand Down Expand Up @@ -136,12 +137,20 @@ def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType:


@_register_builtin_type_synthesizer
def if_(cond: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType:
assert isinstance(cond, ts.ScalarType) and cond.kind == ts.ScalarKind.BOOL
def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType:
if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType):
return tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda elts: ts.TupleType(types=[*elts]),
)(functools.partial(if_, pred))(true_branch, false_branch)

assert not isinstance(true_branch, ts.TupleType) and not isinstance(false_branch, ts.TupleType)
assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL
# TODO(tehrengruber): Enable this or a similar check. In case the true- and false-branch are
# iterators defined on different positions this fails. For the GTFN backend we also don't
# want this, but for roundtrip it is totally fine.
# assert true_branch == false_branch # noqa: ERA001

return true_branch


Expand Down Expand Up @@ -278,24 +287,6 @@ def applied_as_fieldop(*fields) -> ts.FieldType:
return applied_as_fieldop


@_register_builtin_type_synthesizer
def cond(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType:
def type_synthesizer_per_element(
pred: ts.ScalarType,
true_branch: ts.DataType,
false_branch: ts.DataType,
):
assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL
assert true_branch == false_branch

return true_branch

return tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda elts: ts.TupleType(types=[*elts]),
)(functools.partial(type_synthesizer_per_element, pred))(true_branch, false_branch)


@_register_builtin_type_synthesizer
def scan(
scan_pass: TypeSynthesizer, direction: ts.ScalarType, init: ts.ScalarType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,15 @@ def translate_as_field_op(
return [(field_node, field_type)]


def translate_cond(
def translate_if(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for the `cond` builtin function."""
assert cpm.is_call_to(node, "cond")
"""Generates the dataflow subgraph for the `if_` builtin function (outside of `as_fieldop`)."""
assert cpm.is_call_to(node, "if_")
assert len(node.args) == 3
cond_expr, true_expr, false_expr = node.args

Expand Down Expand Up @@ -490,7 +490,7 @@ def translate_symbol_ref(
# Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol
__primitive_translators: list[PrimitiveTranslator] = [
translate_as_field_op,
translate_cond,
translate_if,
translate_literal,
translate_scalar_expr,
translate_symbol_ref,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ def visit_FunCall(
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[gtir_builtin_translators.TemporaryData]:
# use specialized dataflow builder classes for each builtin function
if cpm.is_call_to(node, "cond"):
return gtir_builtin_translators.translate_cond(
if cpm.is_call_to(node, "if_"):
return gtir_builtin_translators.translate_if(
node, sdfg, head_state, self, reduce_identity
)
elif cpm.is_call_to(node.fun, "as_fieldop"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]):
parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.cond("a", "b", "c")
reference = im.if_("a", "b", "c")

assert lowered.expr == reference

Expand All @@ -234,7 +234,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]):
parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.cond("a", "b", "c")
reference = im.if_("a", "b", "c")

assert lowered.expr == reference

Expand All @@ -252,7 +252,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]):
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered_inlined)

reference = im.tuple_get(0, im.cond("a", im.make_tuple("b"), im.make_tuple("c")))
reference = im.tuple_get(0, im.if_("a", im.make_tuple("b"), im.make_tuple("c")))

assert lowered_inlined.expr == reference

Expand All @@ -272,7 +272,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]):
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered_inlined)

reference = im.cond("a", "b", im.cond("a", "c", "b"))
reference = im.if_("a", "b", im.if_("a", "c", "b"))

assert lowered_inlined.expr == reference

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ def expression_test_cases():
)(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)),
ts.TupleType(types=[float_i_field, float_i_field]),
),
# cond
# if in field-view scope
(
im.cond(
im.if_(
False,
im.call(
im.call("as_fieldop")(
Expand All @@ -195,7 +195,7 @@ def expression_test_cases():
float_i_field,
),
(
im.cond(
im.if_(
False,
im.make_tuple(im.ref("inp", float_i_field), im.ref("inp", float_i_field)),
im.make_tuple(im.ref("inp", float_i_field), im.ref("inp", float_i_field)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def test_cond(offset_provider):

cond = im.deref("cond_")

testee = im.cond(cond, field_1, field_2)
testee = im.if_(cond, field_1, field_2)

domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)})
domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider)
Expand All @@ -515,7 +515,7 @@ def test_cond(offset_provider):
expected_field_1 = im.as_fieldop(stencil1, domain)(im.ref("in_field1"))
expected_field_2 = im.as_fieldop(stencil2, domain)(im.ref("in_field2"), expected_tmp2)

expected = im.cond(cond, expected_field_1, expected_field_2)
expected = im.if_(cond, expected_field_1, expected_field_2)

actual_call, actual_domains = infer_domain.infer_expr(
testee, SymbolicDomain.from_expr(domain), offset_provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_gtir_cond():
gtir.SetAt(
expr=im.op_as_fieldop("plus", domain)(
"x",
im.cond(
im.if_(
im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")),
im.op_as_fieldop("plus", domain)("y", "scalar"),
im.op_as_fieldop("plus", domain)("w", "scalar"),
Expand Down Expand Up @@ -379,10 +379,10 @@ def test_gtir_cond_nested():
declarations=[],
body=[
gtir.SetAt(
expr=im.cond(
expr=im.if_(
gtir.SymRef(id="pred_1"),
im.op_as_fieldop("plus", domain)("x", 1.0),
im.cond(
im.if_(
gtir.SymRef(id="pred_2"),
im.op_as_fieldop("plus", domain)("x", 2.0),
im.op_as_fieldop("plus", domain)("x", 3.0),
Expand Down Expand Up @@ -1168,7 +1168,7 @@ def test_gtir_reduce_with_cond_neighbors():
),
vertex_domain,
)(
im.cond(
im.if_(
gtir.SymRef(id="pred"),
im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain),
im.as_fieldop_neighbors("V2E", "edges", vertex_domain),
Expand Down Expand Up @@ -1360,7 +1360,7 @@ def test_gtir_let_lambda_with_cond():
gtir.SetAt(
expr=im.let("x1", "x")(
im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))(
im.cond(
im.if_(
gtir.SymRef(id="pred"),
im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"),
im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"),
Expand Down Expand Up @@ -1402,7 +1402,7 @@ def test_gtir_if_scalars():
gtir.SetAt(
expr=im.op_as_fieldop("plus", domain)(
"x",
im.cond(
im.if_(
"pred",
im.call("cast_")("y_0", "float64"),
im.call("cast_")("y_1", "float64"),
Expand Down

0 comments on commit ac513a1

Please sign in to comment.