diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 962d175eb1..e2aa98f3cf 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -398,6 +398,10 @@ def visit_Assert(self, node: ast.Assert): # Assertions are removed in the AssertionChecker later. return node + def visit_While(self, node: ast.While): + node.body = self._process_stmts(node.body) + return node + def visit_Assign(self, node: ast.Assign): if ( isinstance(node.value, ast.Call) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 976f9a89af..398e312af3 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -716,3 +716,28 @@ def test(out: Field[np.float64], inp: GlobalTable[F64_VEC4]): out = gt_storage.zeros(backend=backend, shape=(2, 2, 2), dtype=np.float64) test(out, inp) assert (out[:] == 42).all() + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_function_inline_in_while(backend): + @gtscript.function + def add_42(v): + return v + 42 + + @gtscript.stencil(backend=backend) + def test( + in_field: Field[np.float64], + out_field: Field[np.float64], + ): + with computation(PARALLEL), interval(...): + count = 1 + while count < 10: + sa = add_42(out_field) + out_field = in_field + sa + count = count + 1 + + domain = (5, 5, 2) + in_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64) + out_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64) + test(in_arr, out_arr) + assert (out_arr[:, :, :] == 388.0).all()