Skip to content

Commit

Permalink
Extend astype() for tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
Nina Burgdorfer authored and Nina Burgdorfer committed Oct 19, 2023
1 parent f96ead5 commit 048d4c0
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 11 deletions.
6 changes: 5 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ def where(


@builtin_function
def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field:
def astype(
field: Field | gt4py_defs.ScalarT | Tuple[Field, ...],
type_: type,
/,
) -> Field | Tuple[Field, ...]:
raise NotImplementedError()


Expand Down
21 changes: 17 additions & 4 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call:
return self._visit_reduction(node, **kwargs)

def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
return_type: ts.TupleType | ts.ScalarType | ts.FieldType
value, new_type = node.args
assert isinstance(
value.type, (ts.FieldType, ts.ScalarType)
value.type, (ts.FieldType, ts.ScalarType, ts.TupleType)
) # already checked using generic mechanism

if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [
kind.name for kind in ts.ScalarKind
]:
Expand All @@ -835,9 +837,20 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.",
)

return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
)
if isinstance(value, foast.TupleExpr):
element_types_new = []
for element in value.elts:
element_types_new.append(
with_altered_scalar_kind(
element.type, getattr(ts.ScalarKind, new_type.id.upper())
)
)
return_type = ts.TupleType(types=cast(list[ts.DataType], element_types_new))

else:
return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
)

return foast.Call(
func=node.func,
Expand Down
22 changes: 16 additions & 6 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,22 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, dtype = node.args[0], node.args[1].id

# TODO check that we test astype that results in a itir.map_ operation
return self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))),
obj,
)
if isinstance(obj, foast.TupleExpr):
casted_elements = []
for _, element in enumerate(obj.elts):
casted_element = self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))), element
)
casted_elements.append(casted_element)
args = [f"__arg{i}" for i in range(len(casted_elements))]
return im.lift(im.lambda_(*args)(im.make_tuple(*[im.deref(arg) for arg in args])))(
*casted_elements
)
else:
return self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))),
obj,
)

def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall:
return self._map("if_", *node.args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@
import pytest

import gt4py.next as gtx
from gt4py.eve import SymbolRef
from gt4py.next import float32, float64
from gt4py.next.ffront.fbuiltins import astype
from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering
from gt4py.next.ffront.func_to_foast import FieldOperatorParser
from gt4py.next.iterator import ir as itir, ir_makers as im
from gt4py.next.type_system import type_specifications as ts


TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests.
Expand All @@ -42,3 +46,22 @@ def fieldop_with_typealias(
foast_tree.body.stmts[0].value.left.func.id == expected
and foast_tree.body.stmts[0].value.right.args[1].id == expected
)


def test_type_alias_replacement_astype_with_tuples():
def fieldop_with_typealias_with_tuples(
a: gtx.Field[[TDim], vpfloat], b: gtx.Field[[TDim], vpfloat]
) -> tuple[gtx.Field[[TDim], wpfloat], gtx.Field[[TDim], wpfloat]]:
return astype((a, b), wpfloat)

parsed = FieldOperatorParser.apply_to_function(fieldop_with_typealias_with_tuples)
lowered = FieldOperatorLowering.apply(parsed)

# Check that the type of the first arg of "astype" is a tuple
assert isinstance(parsed.body.stmts[0].value.args[0].type, ts.TupleType)
# Check that the return type of "astype" is a tuple
assert isinstance(parsed.body.stmts[0].value.type, ts.TupleType)
# Check inside the lift function that make_tuple is applied to return a tuple
assert lowered.expr.fun.args[0].expr.fun == itir.SymRef(id=SymbolRef("make_tuple"))
# Check that the elements that form the tuple called the cast_ function individually
assert lowered.expr.args[0].fun.args[0].expr.fun.expr.fun == itir.SymRef(id=SymbolRef("cast_"))

0 comments on commit 048d4c0

Please sign in to comment.