Skip to content

Commit

Permalink
fix out-of-bounds access in column mode
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Sep 12, 2023
1 parent 2e9eec0 commit b22d963
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 54 deletions.
99 changes: 50 additions & 49 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import dataclasses
import itertools
import math
import sys
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -604,6 +605,12 @@ class Undefined:
def __float__(self):
return np.nan

def __int__(self):
return sys.maxsize

def __repr__(self):
return "_UNDEFINED"

@classmethod
def _setup_math_operations(cls):
ops = [
Expand Down Expand Up @@ -678,7 +685,8 @@ def _single_vertical_idx(
indices: NamedFieldIndices, column_axis: Tag, column_index: common.IntIndex
) -> NamedFieldIndices:
transformed = {
axis: (index if axis != column_axis else column_index) for axis, index in indices.items()
axis: (index if axis != column_axis else index.start + column_index) # type: ignore[union-attr] # trust me, `index` is range in case of `column_axis`
for axis, index in indices.items()
}
return transformed

Expand Down Expand Up @@ -725,55 +733,46 @@ def _make_tuple(
named_indices: NamedFieldIndices,
*,
column_axis: Optional[Tag] = None,
) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...]:
column_range = column_range_cvar.get()
if isinstance(field_or_tuple, tuple):
if column_axis is not None:
assert column_range
# construct a Column of tuples
first = tuple(
_make_tuple(f, _single_vertical_idx(named_indices, column_axis, column_range.start))
for f in field_or_tuple
)
col = Column(
column_range.start, np.zeros(len(column_range), dtype=_column_dtype(first))
)
col[0] = first
for i in column_range[1:]:
col[i] = tuple(
_make_tuple(f, _single_vertical_idx(named_indices, column_axis, i))
for f in field_or_tuple
)
return col
else:
) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...] | Undefined:
if column_axis is None:
if isinstance(field_or_tuple, tuple):
return tuple(_make_tuple(f, named_indices) for f in field_or_tuple)
else:
if column_axis is not None:
data = np.full((len(column_range),), np.nan)
rng = named_indices[column_axis]

start = rng.start
start_target = 0
stop = rng.stop
stop_target = len(column_range)
if rng.start < column_range.start:
start = column_range.start
start_target = column_range.start - rng.start
if rng.stop > column_range.stop:
stop = column_range.stop
stop_target = column_range.stop - rng.stop
named_indices[column_axis] = range(start, stop)

data[start_target:stop_target] = field_or_tuple.field_getitem(named_indices)
# wraps a vertical slice of an input field into a `Column`
assert column_range is not None
return Column(column_range.start, data)
else:
try:
data = field_or_tuple.field_getitem(named_indices)
return data
except embedded_exceptions.IndexOutOfBounds:
return np.nan # TODO what about non floats
return _UNDEFINED
else:
column_range = column_range_cvar.get()
assert column_range is not None

col: list[npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...] | Undefined] = []
for i in column_range:
# we don't know the buffer size, therefore we have to try.
try:
col.append(
tuple(
_make_tuple(
f,
_single_vertical_idx(
named_indices, column_axis, i - column_range.start
),
)
for f in field_or_tuple
)
if isinstance(field_or_tuple, tuple)
else _make_tuple(
field_or_tuple,
_single_vertical_idx(named_indices, column_axis, i - column_range.start),
)
)
except embedded_exceptions.IndexOutOfBounds:
col.append(_UNDEFINED)

first = next((v for v in col if v != _UNDEFINED))
dtype = _column_dtype(first)
return Column(column_range.start, np.asarray(col, dtype=dtype))


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -890,7 +889,9 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField):
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return self._ndarrayfield.__gt_dims__

def _translate_named_indices(self, _named_indices: NamedFieldIndices) -> common.DomainSlice:
def _translate_named_indices(
self, _named_indices: NamedFieldIndices
) -> common.AbsoluteIndexSequence:
named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = {
d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__
}
Expand Down Expand Up @@ -1062,8 +1063,8 @@ def remap(self, index_field: common.Field) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

def restrict(self, item: common.FieldSlice) -> common.Field | core_defs.int32:
if common.is_domain_slice(item) and all(common.is_named_index(e) for e in item):
def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32:
if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code
d, r = item[0]
assert d == self._dimension
assert isinstance(r, int)
Expand Down Expand Up @@ -1168,7 +1169,7 @@ def remap(self, index_field: common.Field) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

def restrict(self, item: common.FieldSlice) -> common.Field | core_defs.ScalarT:
def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT:
# TODO set a domain...
return self._value

Expand Down Expand Up @@ -1362,7 +1363,7 @@ def shift(self, *offsets: OffsetPart) -> ScanArgIterator:

def shifted_scan_arg(k_pos: int) -> Callable[[ItIterator], ScanArgIterator]:
def impl(it: ItIterator) -> ScanArgIterator:
return ScanArgIterator(it, k_pos=k_pos)
return ScanArgIterator(it, k_pos=k_pos) # here we evaluate the full column in every step

return impl

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,53 @@ def k_level_condition_upper(k_idx, k_level):
return if_(deref(k_idx) < deref(k_level), deref(shift(K, +1)(k_idx)), 0)


@fundef
def k_level_condition_upper_tuple(k_idx, k_level):
shifted_val = deref(shift(K, +1)(k_idx))
return if_(
tuple_get(0, deref(k_idx)) < deref(k_level),
tuple_get(0, shifted_val) + tuple_get(1, shifted_val),
0,
)


@pytest.mark.parametrize(
"fun, k_level, ref_function",
"fun, k_level, inp_function, ref_function",
[
(k_level_condition_lower, lambda inp: 0, lambda inp: np.concatenate([[0], inp[:-1]])),
(
k_level_condition_lower,
lambda inp: 0,
lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)),
lambda inp: np.concatenate([[0], inp[:-1]]),
),
(
k_level_condition_upper,
lambda inp: inp.shape[0] - 1,
lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)),
lambda inp: np.concatenate([inp[1:], [0]]),
),
(
k_level_condition_upper_tuple,
lambda inp: inp[0].shape[0] - 1,
lambda k_size: (
gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)),
gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)),
),
lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]),
),
],
)
def test_k_level_condition(program_processor, lift_mode, fun, k_level, ref_function):
def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function):
program_processor, validate = program_processor

if program_processor == run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: tuple arguments")

k_size = 5
inp = gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32))
inp = inp_function(k_size)
ref = ref_function(inp)

out = gtx.np_as_located_field(KDim)(np.zeros_like(inp))
out = gtx.np_as_located_field(KDim)(np.zeros((5,), dtype=np.int32))

run_processor(
fun[{KDim: range(0, k_size)}],
Expand Down

0 comments on commit b22d963

Please sign in to comment.