diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9fc1b42038..9db6535aeb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -204,12 +204,14 @@ def remap( def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: new_domain, buffer_slice = self._slice(index) - new_buffer = self.ndarray[buffer_slice] - if len(new_domain) == 0: - # TODO: assert core_defs.is_scalar_type(new_buffer), new_buffer - return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here + if new_domain.ndim == 0: + buffer_val = self.ndarray[buffer_slice] + assert core_defs.is_scalar_type(buffer_val) + new_buffer = self._scalar_to_field(buffer_val) else: - return self.__class__.from_array(new_buffer, domain=new_domain) + new_buffer = self.ndarray[buffer_slice] + + return self.__class__.from_array(new_buffer, domain=new_domain) __getitem__ = restrict @@ -302,6 +304,12 @@ def _slice( assert common.is_relative_index_sequence(slice_) return new_domain, slice_ + def _scalar_to_field(self, value: core_defs.ScalarT) -> NdArrayField: + if self.array_ns == cp: + return cp.asarray(value) + else: + return np.asarray(value) + @dataclasses.dataclass(frozen=True) class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index cd75538da7..5845c2020e 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -196,6 +196,14 @@ def broadcast( ) +@BuiltInFunction +def as_scalar( + zero_d_field: common.Field, + /, +) -> common.Field: + return common.field(np.asarray(zero_d_field.ndarray), domain=common.Domain(dims=(), ranges=())) + + @WhereBuiltinFunction def where( mask: common.Field, @@ -301,6 +309,7 @@ def impl( "where", "astype", "as_offset", + "as_scalar", ] + MATH_BUILTIN_NAMES BUILTIN_NAMES = TYPE_BUILTIN_NAMES + FUN_BUILTIN_NAMES diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 5e289af664..0b6c19d072 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -931,6 +931,15 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: location=node.location, ) + def _visit_as_scalar(self, node: foast.Call, **kwargs) -> foast.Call: + return foast.Call( + func=node.func, + args=node.args, + kwargs=node.kwargs, + location=node.location, + type=node.args[0].type, + ) + def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 70c79d7b6c..c077a2d545 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -40,6 +40,7 @@ E2V, V2E, E2VDim, + Edge, IDim, Ioff, JDim, @@ -320,6 +321,21 @@ def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]: cases.verify(cartesian_case, testee, inp, out=out, ref=expected) +def test_as_scalar(cartesian_case): + @gtx.field_operator + def testee_fo(a: cases.IFloatField) -> cases.IFloatField: + return a + + @gtx.program + def testee_prog(a: cases.IFloatField): + testee_fo(a, out=a[1:2]) + + a = cases.allocate(cartesian_case, testee_prog, "a")() + ref = np.asarray(a.asnumpy()[1]) + + cases.verify(cartesian_case, testee_prog, a, inout=a[1], ref=ref) + + def test_astype_int(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: @@ -499,21 +515,16 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: @pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator - def testee(a: cases.EField) -> cases.EField: + def testee(a: gtx.Field[[Edge, KDim], int]) -> gtx.Field[[Vertex, KDim], int]: tmp = neighbor_sum(a(V2E), axis=V2EDim) - tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim) - return tmp_2 + # tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim) + return tmp cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a: np.sum( - np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1)[ - unstructured_case.offset_provider["E2V"].table - ], - axis=1, - ), - comparison=lambda a, tmp_2: np.all(a == tmp_2), + ref=lambda a: np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1), + comparison=lambda a, tmp: np.all(a == tmp), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 2174871f89..7eec986386 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -22,6 +22,7 @@ DimensionKind, Field, FieldOffset, + as_scalar, astype, broadcast, common, @@ -618,6 +619,18 @@ def mismatched_lit() -> Field[[TDim], "float32"]: _ = FieldOperatorParser.apply_to_function(mismatched_lit) +def test_as_scalar(): + + def simple_as_scalar(a: Field[[], float64]): + return as_scalar(a) + + parsed = FieldOperatorParser.apply_to_function(simple_as_scalar) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + def test_broadcast_multi_dim(): ADim = Dimension("ADim") BDim = Dimension("BDim") diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 6863b09c12..8a4086f616 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -470,8 +470,8 @@ def test_absolute_indexing_value_return(): named_index = ((IDim, 12), (JDim, 6)) value = field[named_index] - assert isinstance(value, np.int32) - assert value == 21 + assert isinstance(value.asnumpy(), np.ndarray) + assert value.asnumpy() == np.asarray(21) @pytest.mark.parametrize( @@ -568,12 +568,15 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_value", - [((1, 0), 10), ((0, 1), 1)], + [ + ((1, 0), common.field(np.asarray(10), domain=common.Domain(dims=(), ranges=()))), + ((0, 1), common.field(np.asarray(1), domain=common.Domain(dims=(), ranges=()))), + ], ) def test_relative_indexing_value_return(index, expected_value): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common.field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) - indexed_field = field[index] + indexed_field = fbuiltins.as_scalar(field[index]) assert indexed_field == expected_value