Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Extend astype to work with tuples #1352

Merged
merged 19 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ def where(


@builtin_function
def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field:
def astype(
field: Field | gt4py_defs.ScalarT | Tuple[Field, ...],
type_: type,
/,
) -> Field | Tuple[Field, ...]:
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError()


Expand Down
19 changes: 15 additions & 4 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call:
return self._visit_reduction(node, **kwargs)

def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
return_type: ts.TupleType | ts.ScalarType | ts.FieldType
value, new_type = node.args
assert isinstance(
value.type, (ts.FieldType, ts.ScalarType)
value.type, (ts.FieldType, ts.ScalarType, ts.TupleType)
) # already checked using generic mechanism

if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [
kind.name for kind in ts.ScalarKind
]:
Expand All @@ -835,9 +837,18 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.",
)

return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
)
if isinstance(value, foast.TupleExpr):
ninaburg marked this conversation as resolved.
Show resolved Hide resolved
return_type = type_info.apply_to_primitive_constituents(
value.type,
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
)

else:
return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
)
ninaburg marked this conversation as resolved.
Show resolved Hide resolved

return foast.Call(
func=node.func,
Expand Down
21 changes: 16 additions & 5 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,11 +319,22 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, dtype = node.args[0], node.args[1].id

# TODO check that we test astype that results in a itir.map_ operation
return self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))),
obj,
)
def recursive_cast(obj, dtype):
if isinstance(obj, foast.TupleExpr):
casted_elements = []

for element in obj.elts:
casted_element = recursive_cast(element, dtype)
casted_elements.append(casted_element)

return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))(
*casted_elements
)

else:
return self._map(im.lambda_("it")(im.call("cast_")("it", str(dtype))), obj)

return recursive_cast(obj, dtype)

def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall:
return self._map("if_", *node.args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,36 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
)


def test_astype_on_tuples(cartesian_case):
@gtx.field_operator
def cast_tuple(
a: cases.IFloatField, b: cases.IFloatField
) -> tuple[gtx.Field[[IDim], int64], gtx.Field[[IDim], int64]]:
return astype((a, b), int64)

@gtx.field_operator
def combine(a: cases.IFloatField, b: cases.IFloatField) -> gtx.Field[[IDim], int64]:
packed_tuple = cast_tuple(a, b)
return packed_tuple[0] + packed_tuple[1]
ninaburg marked this conversation as resolved.
Show resolved Hide resolved

cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + b)


def test_astype_on_nested_tuples(cartesian_case):
ninaburg marked this conversation as resolved.
Show resolved Hide resolved
@gtx.field_operator
def cast_nested_tuple(
a: cases.IField, b: cases.IField
) -> tuple[gtx.Field[[IDim], int64], tuple[gtx.Field[[IDim], int64], gtx.Field[[IDim], int64]]]:
return astype((a, (a, b)), int64)

@gtx.field_operator
def combine(a: cases.IField, b: cases.IField) -> gtx.Field[[IDim], int64]:
nested_tuple = cast_nested_tuple(a, b)
return nested_tuple[0] + nested_tuple[1][0] + nested_tuple[1][1]

cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b)


def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]):

def test_astype_wrong_value_type():
def simple_astype(a: Field[[TDim], float64]):
# we just use a tuple here but anything that is not a field or scalar works
return astype((1, 2), bool)
# we just use broadcast here but anything that is not a field, scalar or tuple thereof works
return astype(broadcast, bool)

with pytest.raises(errors.DSLError) as exc_info:
_ = FieldOperatorParser.apply_to_function(simple_astype)
Expand Down