diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 6237704f69..a550db4f2e 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -444,6 +444,8 @@ def shape(self) -> tuple[int, ...]: ... @property def dtype(self) -> Any: ... + def item(self) -> Any: ... + def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... def __getitem__(self, item: Any) -> NDArrayObject: ... diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 90e76d671d..fdf515d2f8 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -623,14 +623,17 @@ def asnumpy(self) -> np.ndarray: ... def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... + def restrict(self, item: AnyIndexSpec) -> Field: ... + + @abc.abstractmethod + def as_scalar(self) -> core_defs.ScalarT: ... # Operators @abc.abstractmethod def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... + def __getitem__(self, item: AnyIndexSpec) -> Field: ... @abc.abstractmethod def __abs__(self) -> Field: ... @@ -896,6 +899,9 @@ def ndarray(self) -> Never: def asnumpy(self) -> Never: raise NotImplementedError() + def as_scalar(self) -> Never: + raise NotImplementedError() + @functools.cached_property def domain(self) -> Domain: return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) @@ -947,9 +953,7 @@ def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Conne __call__ = remap - def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar: - if is_int_index(index): - return index + self.offset + def restrict(self, index: AnyIndexSpec) -> Never: raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case __getitem__ = restrict diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 65a71718e4..c39408ba3a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -120,6 +120,13 @@ def asnumpy(self) -> np.ndarray: else: return np.asarray(self._ndarray) + def as_scalar(self) -> core_defs.ScalarT: + if self.domain.ndim != 0: + raise ValueError( + "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + ) + return self.ndarray.item() + @property def codomain(self) -> type[core_defs.ScalarT]: return self.dtype.scalar_type @@ -204,15 +211,11 @@ def remap( __call__ = remap # type: ignore[assignment] - def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + def restrict(self, index: common.AnyIndexSpec) -> common.Field: new_domain, buffer_slice = self._slice(index) - new_buffer = self.ndarray[buffer_slice] - if len(new_domain) == 0: - # TODO: assert core_defs.is_scalar_type(new_buffer), new_buffer - return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new_buffer, domain=new_domain) + new_buffer = self.__class__.array_ns.asarray(new_buffer) + return self.__class__.from_array(new_buffer, domain=new_domain) __getitem__ = restrict @@ -433,7 +436,7 @@ def inverse_image( return new_dims - def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar: + def restrict(self, index: common.AnyIndexSpec) -> common.Field: cache_key = (id(self.ndarray), self.domain, index) if (restricted_connectivity := self._cache.get(cache_key, None)) is None: diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index cb03373b41..fc3ccda335 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -187,8 +187,7 @@ def _tuple_at( ) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]: @utils.tree_map def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: - res = field[pos] if common.is_field(field) else field - res = res.item() if hasattr(res, "item") else res # extract scalar value from array + res = field[pos].as_scalar() if common.is_field(field) else field assert core_defs.is_scalar_type(res) return res diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 0fd263308e..0831fc3bb2 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -289,7 +289,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: index = self._match_index(node.slice) except ValueError: raise errors.DSLError( - self.get_location(node.slice), "eXpected an integral index." + self.get_location(node.slice), "Expected an integral index." ) from None return foast.Subscript( diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index af8f5e8368..0e5be1eabd 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -217,12 +217,12 @@ def visit_Call(self, node: past.Call, **kwargs): f"'{new_kwargs['out'].type}'." ) elif new_func.id in ["minimum", "maximum"]: - if new_args[0].type != new_args[1].type: + if arg_types[0] != arg_types[1]: raise ValueError( f"First and second argument in '{new_func.id}' must be of the same type." - f"Got '{new_args[0].type}' and '{new_args[1].type}'." + f"Got '{arg_types[0]}' and '{arg_types[1]}'." ) - return_type = new_args[0].type + return_type = arg_types[0] else: raise AssertionError( "Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed." diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index ed239e0436..620e98dd4d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -305,7 +305,7 @@ def _visit_stencil_call_out_arg( ) -> tuple[itir.Expr, itir.FunCall]: if isinstance(out_arg, past.Subscript): # as the ITIR does not support slicing a field we have to do a deeper - # inspection of the PAST to emulate the behaviour + # inspection of the PAST to emulate the behaviour out_field_name: past.Name = out_arg.value return ( self._construct_itir_out_arg(out_field_name), @@ -382,12 +382,11 @@ def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall: ) def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: - if node.func.id in ["maximum", "minimum"] and len(node.args) == 2: + if node.func.id in ["maximum", "minimum"]: + assert len(node.args) == 2 return itir.FunCall( fun=itir.SymRef(id=node.func.id), args=[self.visit(node.args[0]), self.visit(node.args[1])], ) else: - raise AssertionError( - "Only 'minimum' and 'maximum' builtins supported supported currently." - ) + raise NotImplementedError("Only 'minimum', and 'maximum' builtins supported currently.") diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 011ca4d92b..a45b81a773 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -919,7 +919,7 @@ def _translate_named_indices( return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: - return self._ndarrayfield[self._translate_named_indices(named_indices)] + return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar() def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if common.is_mutable_field(self._ndarrayfield): @@ -1040,6 +1040,7 @@ class IndexField(common.Field): """ _dimension: common.Dimension + _cur_index: Optional[core_defs.IntegralScalar] = None @property def __gt_domain__(self) -> common.Domain: @@ -1055,7 +1056,10 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: - return common.Domain((self._dimension, common.UnitRange.infinite())) + if self._cur_index is None: + return common.Domain((self._dimension, common.UnitRange.infinite())) + else: + return common.Domain() @property def codomain(self) -> type[core_defs.int32]: @@ -1072,16 +1076,24 @@ def ndarray(self) -> core_defs.NDArrayObject: def asnumpy(self) -> np.ndarray: raise NotImplementedError() + def as_scalar(self) -> core_defs.IntegralScalar: + if self.domain.ndim != 0: + raise ValueError( + "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + ) + assert self._cur_index is not None + return self._cur_index + def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32: + def restrict(self, item: common.AnyIndexSpec) -> common.Field: if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off d, r = item[0] assert d == self._dimension - assert isinstance(r, int) - return self.dtype.scalar_type(r) + assert isinstance(r, core_defs.INTEGRAL_TYPES) + return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work # TODO set a domain... raise NotImplementedError() @@ -1195,8 +1207,12 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + def restrict(self, item: common.AnyIndexSpec) -> common.Field: # TODO set a domain... + return self + + def as_scalar(self) -> core_defs.ScalarT: + assert self.domain.ndim == 0 return self._value __call__ = remap diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7fc2d82e67..4d6168d446 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -321,6 +321,21 @@ def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]: cases.verify(cartesian_case, testee, inp, out=out, ref=expected) +def test_single_value_field(cartesian_case): + @gtx.field_operator + def testee_fo(a: cases.IKField) -> cases.IKField: + return a + + @gtx.program + def testee_prog(a: cases.IKField): + testee_fo(a, out=a[1:2, 3:4]) + + a = cases.allocate(cartesian_case, testee_prog, "a")() + ref = a[1, 3] + + cases.verify(cartesian_case, testee_prog, a, inout=a[1, 3], ref=ref) + + def test_astype_int(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index 69f594a2bc..e4540ba1b9 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -70,7 +70,7 @@ def test_simple_indirection(program_processor): ref = np.zeros(shape, dtype=inp.dtype) for i in range(shape[0]): - ref[i] = inp.ndarray[i + 1 - 1] if cond[i] < 0.0 else inp.ndarray[i + 1 + 1] + ref[i] = inp.asnumpy()[i + 1 - 1] if cond.asnumpy()[i] < 0.0 else inp.asnumpy()[i + 1 + 1] run_processor( conditional_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))], @@ -101,7 +101,7 @@ def test_direct_offset_for_indirection(program_processor): ref = np.zeros(shape) for i in range(shape[0]): - ref[i] = inp[i + cond[i]] + ref[i] = inp.asnumpy()[i + cond.asnumpy()[i]] run_processor( direct_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))], diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 806ab7eb9a..9a1bc6deb6 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -65,11 +65,15 @@ def fencil(x, y, z, out, inp): def naive_lap(inp): shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]] out = np.zeros(shape) + inp_data = inp.asnumpy() for i in range(1, shape[0] + 1): for j in range(1, shape[1] + 1): for k in range(0, shape[2]): - out[i - 1, j - 1, k] = -4 * inp[i, j, k] + ( - inp[i + 1, j, k] + inp[i - 1, j, k] + inp[i, j + 1, k] + inp[i, j - 1, k] + out[i - 1, j - 1, k] = -4 * inp_data[i, j, k] + ( + inp_data[i + 1, j, k] + + inp_data[i - 1, j, k] + + inp_data[i, j + 1, k] + + inp_data[i, j - 1, k] ) return out diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 70fa274457..79830a75a1 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -468,10 +468,11 @@ def test_absolute_indexing_value_return(): field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) named_index = ((IDim, 12), (JDim, 6)) + assert common.is_field(field) value = field[named_index] - assert isinstance(value, np.int32) - assert value == 21 + assert common.is_field(value) + assert value.as_scalar() == 21 @pytest.mark.parametrize( @@ -568,14 +569,17 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_value", - [((1, 0), 10), ((0, 1), 1)], + [ + ((1, 0), 10), + ((0, 1), 1), + ], ) def test_relative_indexing_value_return(index, expected_value): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) indexed_field = field[index] - assert indexed_field == expected_value + assert indexed_field.as_scalar() == expected_value @pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]])