From 5f9891ed81a01d061c2a988abaca6b5b2ed58c17 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 23 Oct 2024 15:09:43 +0200 Subject: [PATCH] bug[next]: foast2gtir lowering of broadcasted field (#1701) Wrap every broadcast in an `as_fieldop` (not only scalars). The materialization of intermediate broadcasted fields need to be optimized by transformations. --- src/gt4py/next/ffront/foast_to_gtir.py | 4 +--- .../next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 9cb0ce05f5..0d0c3868f8 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -374,9 +374,7 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) - if isinstance(node.args[0].type, ts.ScalarType): - return im.as_fieldop(im.ref("deref"))(expr) - return expr + return im.as_fieldop(im.ref("deref"))(expr) def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 09f18246dc..4a1a7cba8e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -915,7 +915,7 @@ def foo(inp: gtx.Field[[TDim], float64]): lowered = FieldOperatorLowering.apply(parsed) assert lowered.id == "foo" - assert lowered.expr == im.ref("inp") + assert lowered.expr == im.as_fieldop("deref")(im.ref("inp")) def test_scalar_broadcast():