feat[next]: Slicing field to 0d to return field not scalar #1427

Merged: 62 commits, Feb 9, 2024 (changes shown from 52 commits)
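Summary of the behavioural change, as a minimal numpy-only sketch (numpy stands in for the field buffer here; this is not the gt4py API itself): slicing a field down to a zero-dimensional domain now returns a 0-d field instead of a bare scalar, and the scalar is extracted explicitly when needed.

import numpy as np

buffer = np.reshape(np.arange(100, dtype=np.int32), (10, 10))

# before: indexing with all-integer indices produced a bare scalar
old_result = buffer[2, 1]              # np.int32(21)

# after: the result stays an array-like, 0-d "field"; the scalar is
# extracted explicitly via .item()
new_result = np.asarray(buffer[2, 1])
assert new_result.ndim == 0
assert new_result.item() == 21
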
Commits (62)
7292b2e
edit to return array from slicing instead of scalar
nfarabullini Jan 25, 2024
75c96ea
removed `as_scalar` builtin
nfarabullini Jan 29, 2024
92f92d1
removed `as_scalar` builtin from test
nfarabullini Jan 29, 2024
4a19cf9
further cleanup
nfarabullini Jan 29, 2024
1e6d90f
further cleanup and fixed pre-commit
nfarabullini Jan 29, 2024
5034f63
removed small thing
nfarabullini Jan 29, 2024
8f64d19
ran pre-commit
nfarabullini Jan 29, 2024
b3f7292
edit to restrict
nfarabullini Jan 29, 2024
1ef427d
edit to restrict and test
nfarabullini Jan 29, 2024
75339ba
small edit for scalar index
nfarabullini Jan 29, 2024
1de57a1
small change in nomenclature
nfarabullini Jan 29, 2024
41bdf23
edit to test file
nfarabullini Jan 29, 2024
bed69b4
fixed tests and ran pre-commit
nfarabullini Jan 30, 2024
20970ab
inverted if condition to avoid negation
nfarabullini Jan 30, 2024
d6b3d93
Merge branch 'main' of https://github.com/nfarabullini/gt4py into zer…
nfarabullini Jan 30, 2024
687e724
edit to test
nfarabullini Jan 30, 2024
621aab0
Update test_horizontal_indirection.py
nfarabullini Jan 30, 2024
4d56367
ran pre-commit
nfarabullini Jan 30, 2024
0c5ce05
removed if condition for specific dims
nfarabullini Jan 31, 2024
e1de878
edits
nfarabullini Jan 31, 2024
ff1d87e
edit for empty field
nfarabullini Jan 31, 2024
ab0a830
Update test_horizontal_indirection.py
nfarabullini Jan 31, 2024
e412aed
edits to tests and typehint
nfarabullini Jan 31, 2024
7fb99e2
Merge branch 'zero_d_return' of https://github.com/nfarabullini/gt4py…
nfarabullini Jan 31, 2024
7d4e8c3
modified assertion
nfarabullini Jan 31, 2024
91ab9e3
modified assertion
nfarabullini Jan 31, 2024
29c1c34
minor edit
nfarabullini Jan 31, 2024
727fadb
edits to embedded and tests
nfarabullini Feb 1, 2024
a7c5393
edits to test
nfarabullini Feb 1, 2024
9ed035e
edits to test
nfarabullini Feb 1, 2024
8feeee3
edits to test
nfarabullini Feb 1, 2024
f6bea1e
edits to operators.py
nfarabullini Feb 1, 2024
ea48fae
ran pre-commit
nfarabullini Feb 1, 2024
9ad29de
small edit
nfarabullini Feb 1, 2024
064654e
Update embedded.py
nfarabullini Feb 1, 2024
8034a59
Update operators.py
nfarabullini Feb 1, 2024
32fcdf8
Update src/gt4py/next/ffront/past_to_itir.py
nfarabullini Feb 2, 2024
93de32d
Update tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
nfarabullini Feb 2, 2024
aa8568e
small edit
nfarabullini Feb 2, 2024
1acb6fb
Update test_execution.py
nfarabullini Feb 2, 2024
6a47b53
Update tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
nfarabullini Feb 5, 2024
c1a4be3
Update tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
nfarabullini Feb 5, 2024
6c3a28a
Update tests/next_tests/integration_tests/feature_tests/iterator_test…
nfarabullini Feb 5, 2024
bc7e1bc
Update tests/next_tests/integration_tests/feature_tests/iterator_test…
nfarabullini Feb 5, 2024
3cfa2af
edits following review
nfarabullini Feb 5, 2024
71ea725
edit to test
nfarabullini Feb 5, 2024
e70edcf
re-introduced _scalar_to_field function
nfarabullini Feb 6, 2024
fd96f26
edit to embedded for _scalar_to_field function
nfarabullini Feb 6, 2024
76cbf20
edits for _scalar_to_field function
nfarabullini Feb 6, 2024
ebb5045
edits to nd_array_field tests
nfarabullini Feb 6, 2024
9308e18
edits to embedded restricts
nfarabullini Feb 6, 2024
23eb8df
small pre-commit edit
nfarabullini Feb 6, 2024
1d50c00
edits following review
nfarabullini Feb 7, 2024
39ec79f
Merge remote-tracking branch 'upstream/main' into zero_d_return
havogt Feb 8, 2024
ccdebc2
edit to embedded
nfarabullini Feb 8, 2024
b50305b
changes to itir as_scalar
havogt Feb 8, 2024
3f68e55
Merge branch 'zero_d_return' of https://github.com/nfarabullini/gt4py…
havogt Feb 8, 2024
99609fa
edit to embedded
nfarabullini Feb 8, 2024
5f59b22
merge with upstream
nfarabullini Feb 8, 2024
5126919
ran pre-commit
nfarabullini Feb 8, 2024
f15783c
fix for ndarray test
nfarabullini Feb 8, 2024
55cf242
disallow slicing CartesianConnectivity
havogt Feb 8, 2024
20 changes: 12 additions & 8 deletions src/gt4py/next/embedded/nd_array_field.py
@@ -203,15 +203,13 @@ def remap(

__call__ = remap # type: ignore[assignment]

def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT:
def restrict(self, index: common.AnyIndexSpec) -> common.Field:
new_domain, buffer_slice = self._slice(index)

new_buffer = self.ndarray[buffer_slice]
if len(new_domain) == 0:
# TODO: assert core_defs.is_scalar_type(new_buffer), new_buffer
return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here
else:
return self.__class__.from_array(new_buffer, domain=new_domain)
if new_domain.ndim == 0:
assert self.array_ns.dtype(new_buffer) in core_defs.SCALAR_TYPES
return self._scalar_to_field(new_buffer) # type: ignore[return-value, arg-type]
return self.__class__.from_array(new_buffer, domain=new_domain)

__getitem__ = restrict

@@ -304,6 +302,12 @@ def _slice(
assert common.is_relative_index_sequence(slice_)
return new_domain, slice_

def _scalar_to_field(self, value: core_defs.Scalar) -> np.ndarray:
if self.array_ns == cp:
return cp.asarray(value)
else:
return np.asarray(value)

Contributor: Not needed, see comment above.

Suggested change
def _scalar_to_field(self, value: core_defs.Scalar) -> np.ndarray:
if self.array_ns == cp:
return cp.asarray(value)
else:
return np.asarray(value)

Contributor: Shouldn't this be removed already?


@dataclasses.dataclass(frozen=True)
class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__
@@ -429,7 +433,7 @@ def inverse_image(

return new_dims

def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar:
def restrict(self, index: common.AnyIndexSpec) -> common.Field:
cache_key = (id(self.ndarray), self.domain, index)

if (restricted_connectivity := self._cache.get(cache_key, None)) is None:
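A hedged standalone sketch of the new restrict control flow shown above (not the gt4py implementation; plain numpy is used for the buffer and for the field's array namespace): if the sliced domain has no dimensions left, the raw value is re-wrapped as a 0-d array of the field's namespace instead of being returned as a scalar.

import numpy as np

def restrict_sketch(buffer, slices, array_ns=np):
    # `array_ns` would be cupy for GPU-backed fields; numpy is assumed here
    new_buffer = buffer[slices]
    if new_buffer.ndim == 0:                    # domain fully collapsed
        return array_ns.asarray(new_buffer)     # 0-d "field" buffer
    return new_buffer                           # stand-in for from_array(...)

data = np.arange(12.0).reshape(3, 4)
assert restrict_sketch(data, (1, 2)).ndim == 0            # 0-d result, value 6.0
assert restrict_sketch(data, (slice(0, 2), 2)).shape == (2,)
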
3 changes: 1 addition & 2 deletions src/gt4py/next/embedded/operators.py
@@ -187,8 +187,7 @@ def _tuple_at(
) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]:
@utils.tree_map
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
res = field[pos] if common.is_field(field) else field
res = res.item() if hasattr(res, "item") else res # extract scalar value from array
res = field[pos].item() if common.is_field(field) else field # type: ignore[union-attr]
assert core_defs.is_scalar_type(res)
return res

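A hedged illustration of the simplified extraction in _tuple_at (a standalone mimic, not the gt4py implementation): once indexing a field always yields an array-like result, every tuple entry collapses to a scalar with a single .item() call instead of the previous hasattr check.

import numpy as np

def tuple_at_sketch(fields, index):
    # recursively index a (possibly nested) tuple of buffers and unwrap scalars
    return tuple(
        tuple_at_sketch(f, index) if isinstance(f, tuple) else f[index].item()
        for f in fields
    )

a = np.arange(4.0)
b = np.arange(4) * 10
assert tuple_at_sketch((a, (b,)), 2) == (2.0, (20,))
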
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_foast.py
@@ -289,7 +289,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript:
index = self._match_index(node.slice)
except ValueError:
raise errors.DSLError(
self.get_location(node.slice), "eXpected an integral index."
self.get_location(node.slice), "Expected an integral index."
) from None

return foast.Subscript(
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/past_passes/type_deduction.py
@@ -217,12 +217,12 @@ def visit_Call(self, node: past.Call, **kwargs):
f"'{new_kwargs['out'].type}'."
)
elif new_func.id in ["minimum", "maximum"]:
if new_args[0].type != new_args[1].type:
if arg_types[0] != arg_types[1]:
raise ValueError(
f"First and second argument in '{new_func.id}' must be of the same type."
f"Got '{new_args[0].type}' and '{new_args[1].type}'."
f"Got '{arg_types[0]}' and '{arg_types[1]}'."
)
return_type = new_args[0].type
return_type = arg_types[0]
else:
raise AssertionError(
"Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed."
9 changes: 4 additions & 5 deletions src/gt4py/next/ffront/past_to_itir.py
@@ -305,7 +305,7 @@ def _visit_stencil_call_out_arg(
) -> tuple[itir.Expr, itir.FunCall]:
if isinstance(out_arg, past.Subscript):
# as the ITIR does not support slicing a field we have to do a deeper
# inspection of the PAST to emulate the behaviour
# inspection of the PAST to emulate the behaviour
out_field_name: past.Name = out_arg.value
return (
self._construct_itir_out_arg(out_field_name),
@@ -382,12 +382,11 @@ def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall:
)

def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall:
if node.func.id in ["maximum", "minimum"] and len(node.args) == 2:
if node.func.id in ["maximum", "minimum"]:
assert len(node.args) == 2
return itir.FunCall(
fun=itir.SymRef(id=node.func.id),
args=[self.visit(node.args[0]), self.visit(node.args[1])],
)
else:
raise AssertionError(
"Only 'minimum' and 'maximum' builtins supported supported currently."
)
raise NotImplementedError("Only 'minimum', and 'maximum' builtins supported currently.")
15 changes: 10 additions & 5 deletions src/gt4py/next/iterator/embedded.py
@@ -748,7 +748,10 @@ def _make_tuple(
else:
try:
data = field_or_tuple.field_getitem(named_indices)
return data
if core_defs.is_scalar_type(data):
return data # type: ignore[return-value] # type assessed in if
assert data.ndim == 0
return data.item()
except embedded_exceptions.IndexOutOfBounds:
return _UNDEFINED
else:
@@ -1086,12 +1089,12 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32:
def restrict(self, item: common.AnyIndexSpec) -> common.Field:
if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off
d, r = item[0]
assert d == self._dimension
assert isinstance(r, int)
return self.dtype.scalar_type(r)
return np.asarray(r) # type: ignore[return-value] # Field is a superset
# TODO set a domain...
raise NotImplementedError()

@@ -1205,9 +1208,11 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT:
def restrict(self, item: common.AnyIndexSpec) -> common.Field:
# TODO set a domain...
return self._value
if core_defs.is_scalar_type(self._value):
return np.asarray(self._value) # type: ignore[return-value] # Field is a superset
return self._value # type: ignore[return-value] # scalar type in if statement above

__call__ = remap
__getitem__ = restrict
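For reference, a hedged numpy-only round-trip of the two changes in this file (standalone stand-ins, not the embedded-iterator API): restrict on index/constant fields now wraps the raw value into a 0-d array so the Field protocol is preserved, and _make_tuple unwraps 0-d buffers back into scalars for the iterator.

import numpy as np

def restrict_constant_sketch(value):
    # wrap a raw scalar into a 0-d array buffer ("Field is a superset")
    return np.asarray(value) if np.isscalar(value) else value

def make_scalar_sketch(data):
    # scalars pass through; 0-d buffers are unpacked with .item()
    if np.isscalar(data):
        return data
    assert data.ndim == 0
    return data.item()

restricted = restrict_constant_sketch(5)
assert restricted.ndim == 0
assert make_scalar_sketch(restricted) == 5
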
@@ -320,6 +320,21 @@ def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]:
cases.verify(cartesian_case, testee, inp, out=out, ref=expected)


def test_single_value_field(cartesian_case):
@gtx.field_operator
def testee_fo(a: cases.IKField) -> cases.IKField:
return a

@gtx.program
def testee_prog(a: cases.IKField):
testee_fo(a, out=a[1:2, 3:4])

a = cases.allocate(cartesian_case, testee_prog, "a")()
ref = a[1, 3]

cases.verify(cartesian_case, testee_prog, a, inout=a[1, 3], ref=ref)


def test_astype_int(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
@@ -70,7 +70,7 @@ def test_simple_indirection(program_processor):

ref = np.zeros(shape, dtype=inp.dtype)
for i in range(shape[0]):
ref[i] = inp.ndarray[i + 1 - 1] if cond[i] < 0.0 else inp.ndarray[i + 1 + 1]
ref[i] = inp.asnumpy()[i + 1 - 1] if cond.asnumpy()[i] < 0.0 else inp.asnumpy()[i + 1 + 1]

run_processor(
conditional_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))],
@@ -101,7 +101,7 @@ def test_direct_offset_for_indirection(program_processor):

ref = np.zeros(shape)
for i in range(shape[0]):
ref[i] = inp[i + cond[i]]
ref[i] = inp.asnumpy()[i + cond.asnumpy()[i]]

run_processor(
direct_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))],
@@ -65,11 +65,15 @@ def fencil(x, y, z, out, inp):
def naive_lap(inp):
shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]]
out = np.zeros(shape)
inp_data = inp.asnumpy()
for i in range(1, shape[0] + 1):
for j in range(1, shape[1] + 1):
for k in range(0, shape[2]):
out[i - 1, j - 1, k] = -4 * inp[i, j, k] + (
inp[i + 1, j, k] + inp[i - 1, j, k] + inp[i, j + 1, k] + inp[i, j - 1, k]
out[i - 1, j - 1, k] = -4 * inp_data[i, j, k] + (
inp_data[i + 1, j, k]
+ inp_data[i - 1, j, k]
+ inp_data[i, j + 1, k]
+ inp_data[i, j - 1, k]
)
return out

@@ -468,10 +468,11 @@ def test_absolute_indexing_value_return():
field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain)

named_index = ((IDim, 12), (JDim, 6))
assert common.is_field(field)
value = field[named_index]

assert isinstance(value, np.int32)
assert value == 21
assert isinstance(value, np.ndarray)
Contributor: Maybe we should clarify this (also @egparedes): do we care what kind the buffer is? I think we only care that the field is still a field, but the buffer could be np.int32 conceptually.

Contributor: Yes, I agree that the main check should be that the field is still a field and that the specific type of the buffer doesn't matter, but I guess it should still be some ndarray-like buffer.

assert value == np.asarray(21)


@pytest.mark.parametrize(
@@ -568,14 +569,17 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain):

@pytest.mark.parametrize(
"index, expected_value",
[((1, 0), 10), ((0, 1), 1)],
[
((1, 0), 10),
((0, 1), 1),
],
)
def test_relative_indexing_value_return(index, expected_value):
domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12)))
field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain)
indexed_field = field[index]

assert indexed_field == expected_value
assert indexed_field.item() == expected_value


@pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]])
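Related to the review discussion above, a short numpy illustration of the two assertion styles used in the updated tests (numpy stands in for the field's buffer): comparing against a 0-d array works elementwise and yields a truthy 0-d boolean, while .item() recovers the exact scalar.

import numpy as np

value = np.asarray(21, dtype=np.int32)   # stand-in for `field[named_index]`

assert isinstance(value, np.ndarray)     # still an array-like buffer
assert value == np.asarray(21)           # 0-d elementwise comparison
assert value.item() == 21                # explicit scalar extraction
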