From b22d96321a3db35b2310d00ec3591f48f9ce19d1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 12 Sep 2023 22:51:47 +0200 Subject: [PATCH] fix out-of-bounds access in column mode --- src/gt4py/next/iterator/embedded.py | 99 ++++++++++--------- .../iterator_tests/test_column_stencil.py | 38 ++++++- 2 files changed, 83 insertions(+), 54 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b9cf63eff9..b0e309f707 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -22,6 +22,7 @@ import dataclasses import itertools import math +import sys from typing import ( Any, Callable, @@ -604,6 +605,12 @@ class Undefined: def __float__(self): return np.nan + def __int__(self): + return sys.maxsize + + def __repr__(self): + return "_UNDEFINED" + @classmethod def _setup_math_operations(cls): ops = [ @@ -678,7 +685,8 @@ def _single_vertical_idx( indices: NamedFieldIndices, column_axis: Tag, column_index: common.IntIndex ) -> NamedFieldIndices: transformed = { - axis: (index if axis != column_axis else column_index) for axis, index in indices.items() + axis: (index if axis != column_axis else index.start + column_index) # type: ignore[union-attr] # trust me, `index` is range in case of `column_axis` + for axis, index in indices.items() } return transformed @@ -725,55 +733,46 @@ def _make_tuple( named_indices: NamedFieldIndices, *, column_axis: Optional[Tag] = None, -) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...]: - column_range = column_range_cvar.get() - if isinstance(field_or_tuple, tuple): - if column_axis is not None: - assert column_range - # construct a Column of tuples - first = tuple( - _make_tuple(f, _single_vertical_idx(named_indices, column_axis, column_range.start)) - for f in field_or_tuple - ) - col = Column( - column_range.start, np.zeros(len(column_range), dtype=_column_dtype(first)) - ) - col[0] = first - for i in column_range[1:]: - col[i] = tuple( - _make_tuple(f, _single_vertical_idx(named_indices, column_axis, i)) - for f in field_or_tuple - ) - return col - else: +) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...] | Undefined: + if column_axis is None: + if isinstance(field_or_tuple, tuple): return tuple(_make_tuple(f, named_indices) for f in field_or_tuple) - else: - if column_axis is not None: - data = np.full((len(column_range),), np.nan) - rng = named_indices[column_axis] - - start = rng.start - start_target = 0 - stop = rng.stop - stop_target = len(column_range) - if rng.start < column_range.start: - start = column_range.start - start_target = column_range.start - rng.start - if rng.stop > column_range.stop: - stop = column_range.stop - stop_target = column_range.stop - rng.stop - named_indices[column_axis] = range(start, stop) - - data[start_target:stop_target] = field_or_tuple.field_getitem(named_indices) - # wraps a vertical slice of an input field into a `Column` - assert column_range is not None - return Column(column_range.start, data) else: try: data = field_or_tuple.field_getitem(named_indices) return data except embedded_exceptions.IndexOutOfBounds: - return np.nan # TODO what about non floats + return _UNDEFINED + else: + column_range = column_range_cvar.get() + assert column_range is not None + + col: list[npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...] | Undefined] = [] + for i in column_range: + # we don't know the buffer size, therefore we have to try. + try: + col.append( + tuple( + _make_tuple( + f, + _single_vertical_idx( + named_indices, column_axis, i - column_range.start + ), + ) + for f in field_or_tuple + ) + if isinstance(field_or_tuple, tuple) + else _make_tuple( + field_or_tuple, + _single_vertical_idx(named_indices, column_axis, i - column_range.start), + ) + ) + except embedded_exceptions.IndexOutOfBounds: + col.append(_UNDEFINED) + + first = next((v for v in col if v != _UNDEFINED)) + dtype = _column_dtype(first) + return Column(column_range.start, np.asarray(col, dtype=dtype)) @dataclasses.dataclass(frozen=True) @@ -890,7 +889,9 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField): def __gt_dims__(self) -> tuple[common.Dimension, ...]: return self._ndarrayfield.__gt_dims__ - def _translate_named_indices(self, _named_indices: NamedFieldIndices) -> common.DomainSlice: + def _translate_named_indices( + self, _named_indices: NamedFieldIndices + ) -> common.AbsoluteIndexSequence: named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = { d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__ } @@ -1062,8 +1063,8 @@ def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.FieldSlice) -> common.Field | core_defs.int32: - if common.is_domain_slice(item) and all(common.is_named_index(e) for e in item): + def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32: + if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code d, r = item[0] assert d == self._dimension assert isinstance(r, int) @@ -1168,7 +1169,7 @@ def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.FieldSlice) -> common.Field | core_defs.ScalarT: + def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: # TODO set a domain... return self._value @@ -1362,7 +1363,7 @@ def shift(self, *offsets: OffsetPart) -> ScanArgIterator: def shifted_scan_arg(k_pos: int) -> Callable[[ItIterator], ScanArgIterator]: def impl(it: ItIterator) -> ScanArgIterator: - return ScanArgIterator(it, k_pos=k_pos) + return ScanArgIterator(it, k_pos=k_pos) # here we evaluate the full column in every step return impl diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index c73019cd4c..5211f2184d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -126,25 +126,53 @@ def k_level_condition_upper(k_idx, k_level): return if_(deref(k_idx) < deref(k_level), deref(shift(K, +1)(k_idx)), 0) +@fundef +def k_level_condition_upper_tuple(k_idx, k_level): + shifted_val = deref(shift(K, +1)(k_idx)) + return if_( + tuple_get(0, deref(k_idx)) < deref(k_level), + tuple_get(0, shifted_val) + tuple_get(1, shifted_val), + 0, + ) + + @pytest.mark.parametrize( - "fun, k_level, ref_function", + "fun, k_level, inp_function, ref_function", [ - (k_level_condition_lower, lambda inp: 0, lambda inp: np.concatenate([[0], inp[:-1]])), + ( + k_level_condition_lower, + lambda inp: 0, + lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + lambda inp: np.concatenate([[0], inp[:-1]]), + ), ( k_level_condition_upper, lambda inp: inp.shape[0] - 1, + lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([inp[1:], [0]]), ), + ( + k_level_condition_upper_tuple, + lambda inp: inp[0].shape[0] - 1, + lambda k_size: ( + gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + ), + lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), + ), ], ) -def test_k_level_condition(program_processor, lift_mode, fun, k_level, ref_function): +def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor + if program_processor == run_dace_iterator: + pytest.xfail("Not supported in DaCe backend: tuple arguments") + k_size = 5 - inp = gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)) + inp = inp_function(k_size) ref = ref_function(inp) - out = gtx.np_as_located_field(KDim)(np.zeros_like(inp)) + out = gtx.np_as_located_field(KDim)(np.zeros((5,), dtype=np.int32)) run_processor( fun[{KDim: range(0, k_size)}],