diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 52aae34b3f..ea79f3d8fd 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -175,7 +175,11 @@ def where( @builtin_function -def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: +def astype( + field: Field | gt4py_defs.ScalarT | Tuple[Field, ...], + type_: type, + /, +) -> Field | Tuple[Field, ...]: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 605b83a5f0..903c871e13 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: + return_type: ts.TupleType | ts.ScalarType | ts.FieldType value, new_type = node.args assert isinstance( - value.type, (ts.FieldType, ts.ScalarType) + value.type, (ts.FieldType, ts.ScalarType, ts.TupleType) ) # already checked using generic mechanism + if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [ kind.name for kind in ts.ScalarKind ]: @@ -835,9 +837,20 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", ) - return_type = with_altered_scalar_kind( - value.type, getattr(ts.ScalarKind, new_type.id.upper()) - ) + if isinstance(value, foast.TupleExpr): + element_types_new = [] + for element in value.elts: + element_types_new.append( + with_altered_scalar_kind( + element.type, getattr(ts.ScalarKind, new_type.id.upper()) + ) + ) + return_type = ts.TupleType(types=cast(list[ts.DataType], element_types_new)) + + else: + return_type = with_altered_scalar_kind( + value.type, getattr(ts.ScalarKind, new_type.id.upper()) + ) return foast.Call( func=node.func, diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 1902d71b3c..3e6d5b3c4e 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -318,12 +318,22 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, dtype = node.args[0], node.args[1].id - - # TODO check that we test astype that results in a itir.map_ operation - return self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), - obj, - ) + if isinstance(obj, foast.TupleExpr): + casted_elements = [] + for _, element in enumerate(obj.elts): + casted_element = self._map( + im.lambda_("it")(im.call("cast_")("it", str(dtype))), element + ) + casted_elements.append(casted_element) + args = [f"__arg{i}" for i in range(len(casted_elements))] + return im.lift(im.lambda_(*args)(im.make_tuple(*[im.deref(arg) for arg in args])))( + *casted_elements + ) + else: + return self._map( + im.lambda_("it")(im.call("cast_")("it", str(dtype))), + obj, + ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args) diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py index e87f869352..3be72bdb33 100644 --- a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py +++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py @@ -19,9 +19,13 @@ import pytest import gt4py.next as gtx +from gt4py.eve import SymbolRef from gt4py.next import float32, float64 from gt4py.next.ffront.fbuiltins import astype +from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.type_system import type_specifications as ts TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. @@ -42,3 +46,22 @@ def fieldop_with_typealias( foast_tree.body.stmts[0].value.left.func.id == expected and foast_tree.body.stmts[0].value.right.args[1].id == expected ) + + +def test_type_alias_replacement_astype_with_tuples(): + def fieldop_with_typealias_with_tuples( + a: gtx.Field[[TDim], vpfloat], b: gtx.Field[[TDim], vpfloat] + ) -> tuple[gtx.Field[[TDim], wpfloat], gtx.Field[[TDim], wpfloat]]: + return astype((a, b), wpfloat) + + parsed = FieldOperatorParser.apply_to_function(fieldop_with_typealias_with_tuples) + lowered = FieldOperatorLowering.apply(parsed) + + # Check that the type of the first arg of "astype" is a tuple + assert isinstance(parsed.body.stmts[0].value.args[0].type, ts.TupleType) + # Check that the return type of "astype" is a tuple + assert isinstance(parsed.body.stmts[0].value.type, ts.TupleType) + # Check inside the lift function that make_tuple is applied to return a tuple + assert lowered.expr.fun.args[0].expr.fun == itir.SymRef(id=SymbolRef("make_tuple")) + # Check that the elements that form the tuple called the cast_ function individually + assert lowered.expr.args[0].fun.args[0].expr.fun.expr.fun == itir.SymRef(id=SymbolRef("cast_"))