diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 13c21eb516..7b96de8e89 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -196,7 +196,11 @@ def where( @builtin_function -def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: +def astype( + field: Field | gt4py_defs.ScalarT | Tuple[Field, ...], + type_: type, + /, +) -> Field | Tuple[Field, ...]: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 605b83a5f0..95c9128f87 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: + return_type: ts.TupleType | ts.ScalarType | ts.FieldType value, new_type = node.args assert isinstance( - value.type, (ts.FieldType, ts.ScalarType) + value.type, (ts.FieldType, ts.ScalarType, ts.TupleType) ) # already checked using generic mechanism + if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [ kind.name for kind in ts.ScalarKind ]: @@ -835,8 +837,11 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", ) - return_type = with_altered_scalar_kind( - value.type, getattr(ts.ScalarKind, new_type.id.upper()) + return_type = type_info.apply_to_primitive_constituents( + value.type, + lambda primitive_type: with_altered_scalar_kind( + primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) + ), ) return foast.Call( diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 1902d71b3c..816b8581f1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -317,12 +317,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, dtype = node.args[0], node.args[1].id - - # TODO check that we test astype that results in a itir.map_ operation - return self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), - obj, + obj, new_type = node.args[0], node.args[1].id + return self._process_elements( + lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -403,6 +400,32 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + def _process_elements( + self, + process_func: Callable[[itir.Expr], itir.Expr], + obj: foast.Expr, + current_el_type: ts.TypeSpec, + current_el_expr: itir.Expr = im.ref("expr"), + ): + """Recursively applies a processing function to all primitive constituents of a tuple.""" + if isinstance(current_el_type, ts.TupleType): + # TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element. + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[ + self._process_elements( + process_func, + obj, + current_el_type.types[i], + im.tuple_get(i, current_el_expr), + ) + for i in range(len(current_el_type.types)) + ] + ) + elif type_info.contains_local_field(current_el_type): + raise NotImplementedError("Processing fields with local dimension is not implemented.") + else: + return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) + class FieldOperatorLoweringError(Exception): ... diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index d381a2242a..58181fd7a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -325,6 +325,76 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +@pytest.mark.uses_tuple_returns +def test_astype_on_tuples(cartesian_case): # noqa: F811 # fixtures + @gtx.field_operator + def field_op_returning_a_tuple( + a: cases.IFloatField, b: cases.IFloatField + ) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]: + tup = (a, b) + return tup + + @gtx.field_operator + def cast_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype(field_op_returning_a_tuple(a, b), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1] == b_casted_to_int_outside_of_gt4py, + ) + + @gtx.field_operator + def cast_nested_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype((a, field_op_returning_a_tuple(a, b)), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1][0] == a_casted_to_int_outside_of_gt4py, + result[1][1] == b_casted_to_int_outside_of_gt4py, + ) + + a = cases.allocate(cartesian_case, cast_tuple, "a")() + b = cases.allocate(cartesian_case, cast_tuple, "b")() + a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) + b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) + out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() + out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() + + cases.verify( + cartesian_case, + cast_tuple, + a, + b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, + out=out_tuple, + ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), + ) + + cases.verify( + cartesian_case, + cast_nested_tuple, + a, + b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, + out=out_nested_tuple, + ref=( + np.full_like(a, True, dtype=bool), + np.full_like(a, True, dtype=bool), + np.full_like(b, True, dtype=bool), + ), + ) + + def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 7800a30e41..dfa710e038 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]): def test_astype_wrong_value_type(): def simple_astype(a: Field[[TDim], float64]): - # we just use a tuple here but anything that is not a field or scalar works - return astype((1, 2), bool) + # we just use broadcast here but anything that is not a field, scalar or tuple thereof works + return astype(broadcast, bool) with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(simple_astype)