diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2e007c28bc..948a8481d7 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -65,10 +65,9 @@ class FieldOperatorLowering(eve.PreserveLocationVisitor, eve.NodeTranslator): Lower FieldOperator AST (FOAST) to GTIR. Most expressions are lowered to `as_fieldop`ed stencils. - Pure scalar expressions are kept as scalar operations as they might appear outside of the stencil context, - e.g. in `cond`. - In arithemtic operations that involve a field and a scalar, the scalar is implicitly broadcasted to a field - in the `as_fieldop` call. + Pure scalar expressions are kept as scalar operations as they might appear outside of the + stencil context. In arithmetic operations that involve a field and a scalar, the scalar is + implicitly broadcasted to a field in the `as_fieldop` call. Examples -------- @@ -164,7 +163,7 @@ def visit_IfStmt( inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) # here we assume neither branch returns - return im.let("__if_stmt_result", im.cond(cond, true_branch, false_branch))(inner_expr) + return im.let("__if_stmt_result", im.if_(cond, true_branch, false_branch))(inner_expr) elif return_kind is foast_introspection.StmtReturnKind.CONDITIONAL_RETURN: common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) @@ -179,7 +178,7 @@ def visit_IfStmt( false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) return im.let(inner_expr_name, inner_expr_evaluator)( - im.cond(cond, true_branch, false_branch) + im.if_(cond, true_branch, false_branch) ) assert return_kind is foast_introspection.StmtReturnKind.UNCONDITIONAL_RETURN @@ -189,7 +188,7 @@ def visit_IfStmt( true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - return im.cond(cond, true_branch, false_branch) + return im.if_(cond, true_branch, false_branch) def visit_Assign( self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any @@ -232,7 +231,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC isinstance(node.condition.type, ts.ScalarType) and node.condition.type.kind == ts.ScalarKind.BOOL ) - return im.cond( + return im.if_( self.visit(node.condition), self.visit(node.true_expr), self.visit(node.false_expr) ) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 28adaaddf1..43e249f5b1 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -188,7 +188,6 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib GTIR_BUILTINS = { *BUILTINS, "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) - "cond", # `cond(expr, field_a, field_b)` creates the field on one branch or the other } diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index a7dc201db9..b2662fa278 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -247,11 +247,6 @@ def if_(cond, true_val, false_val): return call("if_")(cond, true_val, false_val) -def cond(cond, true_val, false_val): - """Create a cond FunCall, shorthand for ``call("cond")(expr)``.""" - return call("cond")(cond, true_val, false_val) - - def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 87f754d644..5104d09d3a 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -260,12 +260,12 @@ def infer_tuple_get( return infered_args_expr, actual_domains -def infer_cond( +def infer_if( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: - assert cpm.is_call_to(expr, "cond") + assert cpm.is_call_to(expr, "if_") infered_args_expr = [] actual_domains: ACCESSED_DOMAINS = {} cond, true_val, false_val = expr.args @@ -293,8 +293,8 @@ def infer_expr( return infer_make_tuple(expr, domain, offset_provider) elif cpm.is_call_to(expr, "tuple_get"): return infer_tuple_get(expr, domain, offset_provider) - elif cpm.is_call_to(expr, "cond"): - return infer_cond(expr, domain, offset_provider) + elif cpm.is_call_to(expr, "if_"): + return infer_if(expr, domain, offset_provider) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 1f2abadb74..77cd39389a 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -70,13 +70,14 @@ def _register_builtin_type_synthesizer( *, fun_names: Optional[Iterable[str]] = None, ): - def wrapper(synthesizer: Callable[..., TypeOrTypeSynthesizer]) -> None: - # store names in function object for better debuggability - synthesizer.fun_names = fun_names or [synthesizer.__name__] # type: ignore[attr-defined] - for f in synthesizer.fun_names: # type: ignore[attr-defined] - builtin_type_synthesizers[f] = TypeSynthesizer(type_synthesizer=synthesizer) + if synthesizer is None: + return functools.partial(_register_builtin_type_synthesizer, fun_names=fun_names) - return wrapper(synthesizer) if synthesizer else wrapper + # store names in function object for better debuggability + synthesizer.fun_names = fun_names or [synthesizer.__name__] # type: ignore[attr-defined] + for f in synthesizer.fun_names: # type: ignore[attr-defined] + builtin_type_synthesizers[f] = TypeSynthesizer(type_synthesizer=synthesizer) + return synthesizer @_register_builtin_type_synthesizer( @@ -136,12 +137,20 @@ def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType: @_register_builtin_type_synthesizer -def if_(cond: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: - assert isinstance(cond, ts.ScalarType) and cond.kind == ts.ScalarKind.BOOL +def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: + if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): + return tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda elts: ts.TupleType(types=[*elts]), + )(functools.partial(if_, pred))(true_branch, false_branch) + + assert not isinstance(true_branch, ts.TupleType) and not isinstance(false_branch, ts.TupleType) + assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL # TODO(tehrengruber): Enable this or a similar check. In case the true- and false-branch are # iterators defined on different positions this fails. For the GTFN backend we also don't # want this, but for roundtrip it is totally fine. # assert true_branch == false_branch # noqa: ERA001 + return true_branch @@ -278,24 +287,6 @@ def applied_as_fieldop(*fields) -> ts.FieldType: return applied_as_fieldop -@_register_builtin_type_synthesizer -def cond(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: - def type_synthesizer_per_element( - pred: ts.ScalarType, - true_branch: ts.DataType, - false_branch: ts.DataType, - ): - assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL - assert true_branch == false_branch - - return true_branch - - return tree_map( - collection_type=ts.TupleType, - result_collection_constructor=lambda elts: ts.TupleType(types=[*elts]), - )(functools.partial(type_synthesizer_per_element, pred))(true_branch, false_branch) - - @_register_builtin_type_synthesizer def scan( scan_pass: TypeSynthesizer, direction: ts.ScalarType, init: ts.ScalarType diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index e47b46ac09..263ea6381b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -237,15 +237,15 @@ def translate_as_field_op( return [(field_node, field_type)] -def translate_cond( +def translate_if( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], ) -> list[TemporaryData]: - """Generates the dataflow subgraph for the `cond` builtin function.""" - assert cpm.is_call_to(node, "cond") + """Generates the dataflow subgraph for the `if_` builtin function (outside of `as_fieldop`).""" + assert cpm.is_call_to(node, "if_") assert len(node.args) == 3 cond_expr, true_expr, false_expr = node.args @@ -490,7 +490,7 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_field_op, - translate_cond, + translate_if, translate_literal, translate_scalar_expr, translate_symbol_ref, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index c955ac37cd..c490824951 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -342,8 +342,8 @@ def visit_FunCall( reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], ) -> list[gtir_builtin_translators.TemporaryData]: # use specialized dataflow builder classes for each builtin function - if cpm.is_call_to(node, "cond"): - return gtir_builtin_translators.translate_cond( + if cpm.is_call_to(node, "if_"): + return gtir_builtin_translators.translate_if( node, sdfg, head_state, self, reduce_identity ) elif cpm.is_call_to(node.fun, "as_fieldop"): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 706de8a3eb..3951c410dc 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -219,7 +219,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.cond("a", "b", "c") + reference = im.if_("a", "b", "c") assert lowered.expr == reference @@ -234,7 +234,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.cond("a", "b", "c") + reference = im.if_("a", "b", "c") assert lowered.expr == reference @@ -252,7 +252,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered_inlined) - reference = im.tuple_get(0, im.cond("a", im.make_tuple("b"), im.make_tuple("c"))) + reference = im.tuple_get(0, im.if_("a", im.make_tuple("b"), im.make_tuple("c"))) assert lowered_inlined.expr == reference @@ -272,7 +272,7 @@ def foo(a: bool, b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered_inlined) - reference = im.cond("a", "b", im.cond("a", "c", "b")) + reference = im.if_("a", "b", im.if_("a", "c", "b")) assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 1709c2e128..acfb1d0bd8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -171,9 +171,9 @@ def expression_test_cases(): )(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), ts.TupleType(types=[float_i_field, float_i_field]), ), - # cond + # if in field-view scope ( - im.cond( + im.if_( False, im.call( im.call("as_fieldop")( @@ -195,7 +195,7 @@ def expression_test_cases(): float_i_field, ), ( - im.cond( + im.if_( False, im.make_tuple(im.ref("inp", float_i_field), im.ref("inp", float_i_field)), im.make_tuple(im.ref("inp", float_i_field), im.ref("inp", float_i_field)), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 51932e0aa0..5f078e8c72 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -504,7 +504,7 @@ def test_cond(offset_provider): cond = im.deref("cond_") - testee = im.cond(cond, field_1, field_2) + testee = im.if_(cond, field_1, field_2) domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)}) domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider) @@ -515,7 +515,7 @@ def test_cond(offset_provider): expected_field_1 = im.as_fieldop(stencil1, domain)(im.ref("in_field1")) expected_field_2 = im.as_fieldop(stencil2, domain)(im.ref("in_field2"), expected_tmp2) - expected = im.cond(cond, expected_field_1, expected_field_2) + expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( testee, SymbolicDomain.from_expr(domain), offset_provider diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 7f61b87b40..694cf7318f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -338,7 +338,7 @@ def test_gtir_cond(): gtir.SetAt( expr=im.op_as_fieldop("plus", domain)( "x", - im.cond( + im.if_( im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")), im.op_as_fieldop("plus", domain)("y", "scalar"), im.op_as_fieldop("plus", domain)("w", "scalar"), @@ -379,10 +379,10 @@ def test_gtir_cond_nested(): declarations=[], body=[ gtir.SetAt( - expr=im.cond( + expr=im.if_( gtir.SymRef(id="pred_1"), im.op_as_fieldop("plus", domain)("x", 1.0), - im.cond( + im.if_( gtir.SymRef(id="pred_2"), im.op_as_fieldop("plus", domain)("x", 2.0), im.op_as_fieldop("plus", domain)("x", 3.0), @@ -1168,7 +1168,7 @@ def test_gtir_reduce_with_cond_neighbors(): ), vertex_domain, )( - im.cond( + im.if_( gtir.SymRef(id="pred"), im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), im.as_fieldop_neighbors("V2E", "edges", vertex_domain), @@ -1360,7 +1360,7 @@ def test_gtir_let_lambda_with_cond(): gtir.SetAt( expr=im.let("x1", "x")( im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.cond( + im.if_( gtir.SymRef(id="pred"), im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"), im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"), @@ -1402,7 +1402,7 @@ def test_gtir_if_scalars(): gtir.SetAt( expr=im.op_as_fieldop("plus", domain)( "x", - im.cond( + im.if_( "pred", im.call("cast_")("y_0", "float64"), im.call("cast_")("y_1", "float64"),