Skip to content

Commit

Permalink
bug[next]: foast2gtir lowering of broadcasted field (#1701)
Browse files Browse the repository at this point in the history
Wrap every broadcast in an `as_fieldop` (not only scalars). The
materialization of intermediate broadcasted fields need to be optimized
by transformations.
  • Loading branch information
havogt authored Oct 23, 2024
1 parent 4eb4d4d commit 5f9891e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
4 changes: 1 addition & 3 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,7 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall:

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
expr = self.visit(node.args[0], **kwargs)
if isinstance(node.args[0].type, ts.ScalarType):
return im.as_fieldop(im.ref("deref"))(expr)
return expr
return im.as_fieldop(im.ref("deref"))(expr)

def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._map(self.visit(node.func, **kwargs), *node.args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ def foo(inp: gtx.Field[[TDim], float64]):
lowered = FieldOperatorLowering.apply(parsed)

assert lowered.id == "foo"
assert lowered.expr == im.ref("inp")
assert lowered.expr == im.as_fieldop("deref")(im.ref("inp"))


def test_scalar_broadcast():
Expand Down

0 comments on commit 5f9891e

Please sign in to comment.