Skip to content

Commit

Permalink
edit to return array from slicing instead of scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Jan 25, 2024
1 parent 9cd9879 commit 7292b2e
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 19 deletions.
18 changes: 13 additions & 5 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,14 @@ def remap(
def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT:
new_domain, buffer_slice = self._slice(index)

new_buffer = self.ndarray[buffer_slice]
if len(new_domain) == 0:
# TODO: assert core_defs.is_scalar_type(new_buffer), new_buffer
return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here
if new_domain.ndim == 0:
buffer_val = self.ndarray[buffer_slice]
assert core_defs.is_scalar_type(buffer_val)
new_buffer = self._scalar_to_field(buffer_val)
else:
return self.__class__.from_array(new_buffer, domain=new_domain)
new_buffer = self.ndarray[buffer_slice]

return self.__class__.from_array(new_buffer, domain=new_domain)

__getitem__ = restrict

Expand Down Expand Up @@ -302,6 +304,12 @@ def _slice(
assert common.is_relative_index_sequence(slice_)
return new_domain, slice_

def _scalar_to_field(self, value: core_defs.ScalarT) -> NdArrayField:
if self.array_ns == cp:
return cp.asarray(value)
else:
return np.asarray(value)


@dataclasses.dataclass(frozen=True)
class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ def broadcast(
)


@BuiltInFunction
def as_scalar(
zero_d_field: common.Field,
/,
) -> common.Field:
return common.field(np.asarray(zero_d_field.ndarray), domain=common.Domain(dims=(), ranges=()))


@WhereBuiltinFunction
def where(
mask: common.Field,
Expand Down Expand Up @@ -301,6 +309,7 @@ def impl(
"where",
"astype",
"as_offset",
"as_scalar",
] + MATH_BUILTIN_NAMES

BUILTIN_NAMES = TYPE_BUILTIN_NAMES + FUN_BUILTIN_NAMES
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,15 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
location=node.location,
)

def _visit_as_scalar(self, node: foast.Call, **kwargs) -> foast.Call:
return foast.Call(
func=node.func,
args=node.args,
kwargs=node.kwargs,
location=node.location,
type=node.args[0].type,
)

def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type)
broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
E2V,
V2E,
E2VDim,
Edge,
IDim,
Ioff,
JDim,
Expand Down Expand Up @@ -320,6 +321,21 @@ def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]:
cases.verify(cartesian_case, testee, inp, out=out, ref=expected)


def test_as_scalar(cartesian_case):
@gtx.field_operator
def testee_fo(a: cases.IFloatField) -> cases.IFloatField:
return a

@gtx.program
def testee_prog(a: cases.IFloatField):
testee_fo(a, out=a[1:2])

a = cases.allocate(cartesian_case, testee_prog, "a")()
ref = np.asarray(a.asnumpy()[1])

cases.verify(cartesian_case, testee_prog, a, inout=a[1], ref=ref)


def test_astype_int(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
Expand Down Expand Up @@ -499,21 +515,16 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField:
@pytest.mark.uses_reduction_over_lift_expressions
def test_nested_reduction(unstructured_case):
@gtx.field_operator
def testee(a: cases.EField) -> cases.EField:
def testee(a: gtx.Field[[Edge, KDim], int]) -> gtx.Field[[Vertex, KDim], int]:
tmp = neighbor_sum(a(V2E), axis=V2EDim)
tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim)
return tmp_2
# tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim)
return tmp

cases.verify_with_default_data(
unstructured_case,
testee,
ref=lambda a: np.sum(
np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1)[
unstructured_case.offset_provider["E2V"].table
],
axis=1,
),
comparison=lambda a, tmp_2: np.all(a == tmp_2),
ref=lambda a: np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1),
comparison=lambda a, tmp: np.all(a == tmp),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DimensionKind,
Field,
FieldOffset,
as_scalar,
astype,
broadcast,
common,
Expand Down Expand Up @@ -618,6 +619,18 @@ def mismatched_lit() -> Field[[TDim], "float32"]:
_ = FieldOperatorParser.apply_to_function(mismatched_lit)


def test_as_scalar():

def simple_as_scalar(a: Field[[], float64]):
return as_scalar(a)

parsed = FieldOperatorParser.apply_to_function(simple_as_scalar)

assert parsed.body.stmts[0].value.type == ts.FieldType(
dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
)


def test_broadcast_multi_dim():
ADim = Dimension("ADim")
BDim = Dimension("BDim")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,8 @@ def test_absolute_indexing_value_return():
named_index = ((IDim, 12), (JDim, 6))
value = field[named_index]

assert isinstance(value, np.int32)
assert value == 21
assert isinstance(value.asnumpy(), np.ndarray)
assert value.asnumpy() == np.asarray(21)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -568,12 +568,15 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain):

@pytest.mark.parametrize(
"index, expected_value",
[((1, 0), 10), ((0, 1), 1)],
[
((1, 0), common.field(np.asarray(10), domain=common.Domain(dims=(), ranges=()))),
((0, 1), common.field(np.asarray(1), domain=common.Domain(dims=(), ranges=()))),
],
)
def test_relative_indexing_value_return(index, expected_value):
domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12)))
field = common.field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain)
indexed_field = field[index]
indexed_field = fbuiltins.as_scalar(field[index])

assert indexed_field == expected_value

Expand Down

0 comments on commit 7292b2e

Please sign in to comment.