Skip to content

Commit

Permalink
feat[next]: Extend astype to work with tuples (#1352)
Browse files Browse the repository at this point in the history
* Extend astype() for tuples

* Adapt existing test for arg types of astype()

* Adress requested style change

* Add extra type check

* Use apply_to_primitive_constituents function on (nested) tuples

* Adress 'nitpicking' change

* Remove previous test and add integration test for casting (nested) tuples

* Adapt visit_astype method with recursive func for nested tuples

* Fix integration test

* Call 'with_altered_scalar_kind' only once

* Recursive 'process_elements' func to apply a func on the elts of a tuple

* Fix execution tests

* Adapt visit_astype for foast.Call and foast.Name

* Fix tests

* Rename args and refactor 'process_elements'

* Fix tests

---------

Co-authored-by: Nina Burgdorfer <[email protected]>
  • Loading branch information
ninaburg and Nina Burgdorfer authored Nov 17, 2023
1 parent da1da20 commit 67a6188
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 12 deletions.
6 changes: 5 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,11 @@ def where(


@builtin_function
def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field:
def astype(
field: Field | gt4py_defs.ScalarT | Tuple[Field, ...],
type_: type,
/,
) -> Field | Tuple[Field, ...]:
raise NotImplementedError()


Expand Down
11 changes: 8 additions & 3 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call:
return self._visit_reduction(node, **kwargs)

def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
return_type: ts.TupleType | ts.ScalarType | ts.FieldType
value, new_type = node.args
assert isinstance(
value.type, (ts.FieldType, ts.ScalarType)
value.type, (ts.FieldType, ts.ScalarType, ts.TupleType)
) # already checked using generic mechanism

if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [
kind.name for kind in ts.ScalarKind
]:
Expand All @@ -835,8 +837,11 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.",
)

return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
return_type = type_info.apply_to_primitive_constituents(
value.type,
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
)

return foast.Call(
Expand Down
35 changes: 29 additions & 6 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:

def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, dtype = node.args[0], node.args[1].id

# TODO check that we test astype that results in a itir.map_ operation
return self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))),
obj,
obj, new_type = node.args[0], node.args[1].id
return self._process_elements(
lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs
)

def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall:
Expand Down Expand Up @@ -403,6 +400,32 @@ def _map(self, op, *args, **kwargs):

return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)

def _process_elements(
self,
process_func: Callable[[itir.Expr], itir.Expr],
obj: foast.Expr,
current_el_type: ts.TypeSpec,
current_el_expr: itir.Expr = im.ref("expr"),
):
"""Recursively applies a processing function to all primitive constituents of a tuple."""
if isinstance(current_el_type, ts.TupleType):
# TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element.
return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))(
*[
self._process_elements(
process_func,
obj,
current_el_type.types[i],
im.tuple_get(i, current_el_expr),
)
for i in range(len(current_el_type.types))
]
)
elif type_info.contains_local_field(current_el_type):
raise NotImplementedError("Processing fields with local dimension is not implemented.")
else:
return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj)


class FieldOperatorLoweringError(Exception):
...
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,76 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
)


@pytest.mark.uses_tuple_returns
def test_astype_on_tuples(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def field_op_returning_a_tuple(
a: cases.IFloatField, b: cases.IFloatField
) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]:
tup = (a, b)
return tup

@gtx.field_operator
def cast_tuple(
a: cases.IFloatField,
b: cases.IFloatField,
a_casted_to_int_outside_of_gt4py: cases.IField,
b_casted_to_int_outside_of_gt4py: cases.IField,
) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]:
result = astype(field_op_returning_a_tuple(a, b), int32)
return (
result[0] == a_casted_to_int_outside_of_gt4py,
result[1] == b_casted_to_int_outside_of_gt4py,
)

@gtx.field_operator
def cast_nested_tuple(
a: cases.IFloatField,
b: cases.IFloatField,
a_casted_to_int_outside_of_gt4py: cases.IField,
b_casted_to_int_outside_of_gt4py: cases.IField,
) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]:
result = astype((a, field_op_returning_a_tuple(a, b)), int32)
return (
result[0] == a_casted_to_int_outside_of_gt4py,
result[1][0] == a_casted_to_int_outside_of_gt4py,
result[1][1] == b_casted_to_int_outside_of_gt4py,
)

a = cases.allocate(cartesian_case, cast_tuple, "a")()
b = cases.allocate(cartesian_case, cast_tuple, "b")()
a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32))
b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32))
out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)()
out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)()

cases.verify(
cartesian_case,
cast_tuple,
a,
b,
a_casted_to_int_outside_of_gt4py,
b_casted_to_int_outside_of_gt4py,
out=out_tuple,
ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)),
)

cases.verify(
cartesian_case,
cast_nested_tuple,
a,
b,
a_casted_to_int_outside_of_gt4py,
b_casted_to_int_outside_of_gt4py,
out=out_nested_tuple,
ref=(
np.full_like(a, True, dtype=bool),
np.full_like(a, True, dtype=bool),
np.full_like(b, True, dtype=bool),
),
)


def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]):

def test_astype_wrong_value_type():
def simple_astype(a: Field[[TDim], float64]):
# we just use a tuple here but anything that is not a field or scalar works
return astype((1, 2), bool)
# we just use broadcast here but anything that is not a field, scalar or tuple thereof works
return astype(broadcast, bool)

with pytest.raises(errors.DSLError) as exc_info:
_ = FieldOperatorParser.apply_to_function(simple_astype)
Expand Down

0 comments on commit 67a6188

Please sign in to comment.