diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 10583b90ff..6cf4cc67fd 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -116,7 +116,31 @@ def visit_FieldOperator( def visit_ScanOperator( self, node: foast.ScanOperator, **kwargs: Any ) -> itir.FunctionDefinition: - raise NotImplementedError("TODO") + # note: we don't need the axis here as this is handled by the program + # decorator + assert isinstance(node.type, ts_ffront.ScanOperatorType) + + # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. + # In iterator IR we didn't properly specify if this is legal, + # however after lift-inlining the expressions are transformed back to literals. + forward = self.visit(node.forward, **kwargs) + init = self.visit(node.init, **kwargs) + + # lower definition function + func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) + new_body = func_definition.expr + + stencil_args: list[itir.Expr] = [] + assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args + for param in func_definition.params[1:]: + new_body = im.let(param.id, im.deref(param.id))(new_body) + stencil_args.append(im.ref(param.id)) + + definition = itir.Lambda(params=func_definition.params, expr=new_body) + + body = im.as_fieldop(im.call("scan")(definition, forward, init))(*stencil_args) + + return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: raise AssertionError("Statements must always be visited in the context of a function.") @@ -324,10 +348,6 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: *lowered_args, *lowered_kwargs.values() ) - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - raise NotImplementedError("TODO") - return result raise AssertionError( diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index afe0cec402..84dd9e3f72 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,6 +54,7 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime +from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation @@ -186,6 +187,12 @@ def mapped_index( NamedFieldIndices: TypeAlias = Mapping[Tag, FieldIndex | SparsePositionEntry] +# Magic local dimension for the result of a `make_const_list`. +# A clean implementation will probably involve to tag the `make_const_list` +# with the neighborhood it is meant to be used with. +_CONST_DIM = common.Dimension(value="_CONST_DIM", kind=common.DimensionKind.LOCAL) + + @runtime_checkable class ItIterator(Protocol): """ @@ -227,6 +234,12 @@ class MutableLocatedField(LocatedField, Protocol): def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ... +def _numpy_structured_value_to_tuples(value: Any) -> Any: + if _elem_dtype(value).names is not None: + return tuple(_numpy_structured_value_to_tuples(v) for v in value) + return value + + class Column(np.lib.mixins.NDArrayOperatorsMixin): """Represents a column when executed in column mode (`column_axis != None`). @@ -247,6 +260,10 @@ def dtype(self) -> np.dtype: # not directly dtype of `self.data` as that might be a structured type containing `None` return _elem_dtype(self.data[self.kstart]) + def __gt_type__(self) -> ts.TypeSpec: + elem = self.data[self.kstart] + return type_translation.from_value(_numpy_structured_value_to_tuples(elem)) + def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] # numpy type @@ -576,17 +593,20 @@ def execute_shift( for i, p in reversed(list(enumerate(new_entry))): # first shift applies to the last sparse dimensions of that axis type if p is None: - offset_implementation = offset_provider[tag] - assert isinstance(offset_implementation, common.Connectivity) - cur_index = pos[offset_implementation.origin_axis.value] - assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ - None, - common._DEFAULT_SKIP_VALUE, - ]: - return None - - new_entry[i] = index + if tag == _CONST_DIM.value: + new_entry[i] = 0 + else: + offset_implementation = offset_provider[tag] + assert isinstance(offset_implementation, common.Connectivity) + cur_index = pos[offset_implementation.origin_axis.value] + assert common.is_int_index(cur_index) + if offset_implementation.mapped_index(cur_index, index) in [ + None, + common._DEFAULT_SKIP_VALUE, + ]: + return None + + new_entry[i] = index break # the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard return cast(IncompletePosition, pos) | {tag: new_entry} @@ -920,9 +940,9 @@ def deref(self) -> Any: return _make_tuple(self.field, position, column_axis=self.column_axis) -def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[Tag]: +def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[common.Dimension]: return [ - axis.value + axis for axis in axes if isinstance(axis, common.Dimension) and axis.kind == common.DimensionKind.LOCAL ] @@ -945,7 +965,7 @@ def make_in_iterator( new_pos: Position = pos.copy() for sparse_dim in set(sparse_dimensions): init = [None] * sparse_dimensions.count(sparse_dim) - new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused + new_pos[sparse_dim.value] = init # type: ignore[assignment] # looks like mypy is confused if column_dimension is not None: column_range = embedded_context.closure_column_range.get().unit_range # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted @@ -956,7 +976,7 @@ def make_in_iterator( ) if len(sparse_dimensions) >= 1: if len(sparse_dimensions) == 1: - return SparseListIterator(it, sparse_dimensions[0]) + return SparseListIterator(it, sparse_dimensions[0].value) else: raise NotImplementedError( f"More than one local dimension is currently not supported, got {sparse_dimensions}." @@ -1004,7 +1024,17 @@ def field_getitem(self, named_indices: NamedFieldIndices) -> Any: def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if isinstance(self._ndarrayfield, common.MutableField): - self._ndarrayfield[self._translate_named_indices(named_indices)] = value + if isinstance(value, _List): + for i, v in enumerate(value): # type:ignore[var-annotated, arg-type] + self._ndarrayfield[ + self._translate_named_indices({**named_indices, value.offset.value: i}) # type: ignore[dict-item] + ] = v + elif isinstance(value, _ConstList): + self._ndarrayfield[ + self._translate_named_indices({**named_indices, _CONST_DIM.value: 0}) + ] = value.value + else: + self._ndarrayfield[self._translate_named_indices(named_indices)] = value else: raise RuntimeError("Assigment into a non-mutable Field is not allowed.") @@ -1383,7 +1413,23 @@ def impl(it: ItIterator) -> ItIterator: DT = TypeVar("DT") -class _List(tuple, Generic[DT]): ... +@dataclasses.dataclass(frozen=True) +class _List(Generic[DT]): + values: tuple[DT, ...] + offset: runtime.Offset + + def __getitem__(self, i: int): + return self.values[i] + + def __gt_type__(self) -> itir_ts.ListType: + offset_tag = self.offset.value + assert isinstance(offset_tag, str) + element_type = type_translation.from_value(self.values[0]) + assert isinstance(element_type, ts.DataType) + return itir_ts.ListType( + element_type=element_type, + offset_type=common.Dimension(value=offset_tag, kind=common.DimensionKind.LOCAL), + ) @dataclasses.dataclass(frozen=True) @@ -1393,6 +1439,14 @@ class _ConstList(Generic[DT]): def __getitem__(self, _): return self.value + def __gt_type__(self) -> itir_ts.ListType: + element_type = type_translation.from_value(self.value) + assert isinstance(element_type, ts.DataType) + return itir_ts.ListType( + element_type=element_type, + offset_type=_CONST_DIM, + ) + @builtins.neighbors.register(EMBEDDED) def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: @@ -1403,9 +1457,12 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: connectivity = offset_provider[offset_str] assert isinstance(connectivity, common.Connectivity) return _List( - shifted.deref() - for i in range(connectivity.max_neighbors) - if (shifted := it.shift(offset_str, i)).can_deref() + values=tuple( + shifted.deref() + for i in range(connectivity.max_neighbors) + if (shifted := it.shift(offset_str, i)).can_deref() + ), + offset=offset, ) @@ -1414,10 +1471,23 @@ def list_get(i, lst: _List[Optional[DT]]) -> Optional[DT]: return lst[i] +def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]: + offsets = set((lst.offset for lst in lists if hasattr(lst, "offset"))) + if len(offsets) == 0: + return None + if len(offsets) == 1: + return offsets.pop() + raise AssertionError("All lists must have the same offset.") + + @builtins.map_.register(EMBEDDED) def map_(op): def impl_(*lists): - return _List(map(lambda x: op(*x), zip(*lists))) + offset = _get_offset(*lists) + if offset is None: + return _ConstList(value=op(*[lst.value for lst in lists])) + else: + return _List(values=tuple(map(lambda x: op(*x), zip(*lists))), offset=offset) return impl_ @@ -1438,7 +1508,7 @@ def sten(*lists): break # we can check a single argument for length, # because all arguments share the same pattern - n = len(lst) + n = len(lst.values) res = init for i in range(n): res = fun(res, *(lst[i] for lst in lists)) @@ -1454,14 +1524,23 @@ class SparseListIterator: offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True) def deref(self) -> Any: + if self.list_offset == _CONST_DIM.value: + return _ConstList( + value=self.it.shift(*self.offsets, SparseTag(self.list_offset), 0).deref() + ) offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[self.list_offset] assert isinstance(connectivity, common.Connectivity) return _List( - shifted.deref() - for i in range(connectivity.max_neighbors) - if (shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i)).can_deref() + values=tuple( + shifted.deref() + for i in range(connectivity.max_neighbors) + if ( + shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i) + ).can_deref() + ), + offset=runtime.Offset(value=self.list_offset), ) def can_deref(self) -> bool: @@ -1654,16 +1733,6 @@ def _extract_column_range(domain) -> common.NamedRange | eve.NothingType: return eve.NOTHING -def _structured_dtype_to_typespec(structured_dtype: np.dtype) -> ts.ScalarType | ts.TupleType: - if structured_dtype.names is None: - return type_translation.from_dtype(core_defs.dtype(structured_dtype)) - return ts.TupleType( - types=[ - _structured_dtype_to_typespec(structured_dtype[name]) for name in structured_dtype.names - ] - ) - - def _get_output_type( fun: Callable, domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, @@ -1682,8 +1751,29 @@ def _get_output_type( with embedded_context.new_context(closure_column_range=col_range) as ctx: single_pos_result = ctx.run(_compute_at_position, fun, args, pos_in_domain, col_dim) assert single_pos_result is not _UNDEFINED, "Stencil contains an Out-Of-Bound access." - dtype = _elem_dtype(single_pos_result) - return _structured_dtype_to_typespec(dtype) + return type_translation.from_value(single_pos_result) + + +def _fieldspec_list_to_value( + domain: common.Domain, type_: ts.TypeSpec +) -> tuple[common.Domain, ts.TypeSpec]: + """Translate the list element type into the domain.""" + if isinstance(type_, itir_ts.ListType): + if type_.offset_type == _CONST_DIM: + return domain.insert( + len(domain), common.named_range((_CONST_DIM, 1)) + ), type_.element_type + else: + offset_provider = embedded_context.offset_provider.get() + offset_type = type_.offset_type + assert isinstance(offset_type, common.Dimension) + connectivity = offset_provider[offset_type.value] + assert isinstance(connectivity, common.Connectivity) + return domain.insert( + len(domain), + common.named_range((offset_type, connectivity.max_neighbors)), + ), type_.element_type + return domain, type_ @builtins.as_fieldop.register(EMBEDDED) @@ -1691,7 +1781,9 @@ def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.Unstruct def impl(*args): xp = field_utils.get_array_ns(*args) type_ = _get_output_type(fun, domain, [promote_scalars(arg) for arg in args]) - out = field_utils.field_from_typespec(type_, common.domain(domain), xp) + + new_domain, type_ = _fieldspec_list_to_value(common.domain(domain), type_) + out = field_utils.field_from_typespec(type_, new_domain, xp) # TODO(havogt): after updating all tests to use the new program, # we should get rid of closure and move the implementation to this function diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 7c35d552dc..0c08bf2b9d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -223,10 +223,7 @@ def apply_fieldview_transforms( ) -> itir.Program: ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = InlineLambdas.apply(ir, opcount_preserving=True) - ir = infer_domain.infer_program( - ir, - offset_provider=offset_provider, - ) + ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 94a174dca4..edb56f5659 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Literal +from typing import Literal, Optional from gt4py.next import common from gt4py.next.type_system import type_specifications as ts @@ -31,6 +31,9 @@ class OffsetLiteralType(ts.TypeSpec): @dataclasses.dataclass(frozen=True) class ListType(ts.DataType): element_type: ts.DataType + # TODO(havogt): the `offset_type` is not yet used in type_inference, + # it is meant to describe the neighborhood (via the local dimension) + offset_type: Optional[common.Dimension] = None @dataclasses.dataclass(frozen=True) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 398e312af3..c4d07d7337 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -366,9 +366,6 @@ def stencil( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets(backend): - if backend == "dace:cpu": - pytest.skip("Internal compiler error in GitHub action container") - @gtscript.stencil(backend=backend) def stencil_ij( in_field: gtscript.Field[np.float_], @@ -391,9 +388,6 @@ def stencil_ijk( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets_and_while_loop(backend): - if backend == "dace:cpu": - pytest.skip("Internal compiler error in GitHub action container") - @gtscript.stencil(backend=backend) def stencil( pe1: gtscript.Field[np.float_], diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 44112f3899..d3a5744389 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -738,29 +738,10 @@ def validation(field_in, field_out, *, domain, origin): field_out[:, :, 0] = field_in[:, :, 0] -def _skip_dace_cpu_gcc_error(backends): - paramtype = type(pytest.param()) - res = [] - for b in backends: - if isinstance(b, paramtype) and b.values[0] == "dace:cpu": - res.append( - pytest.param( - *b.values, - marks=[ - *b.marks, - pytest.mark.skip("Internal compiler error in GitHub action container"), - ], - ) - ) - else: - res.append(b) - return res - - class TestVariableKRead(gt_testing.StencilTestSuite): dtypes = {"field_in": np.float32, "field_out": np.float32, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] - backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) + backends = ALL_BACKENDS symbols = { "field_in": gt_testing.field( in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] @@ -782,7 +763,7 @@ def validation(field_in, field_out, index, *, domain, origin): class TestVariableKAndReadOutside(gt_testing.StencilTestSuite): dtypes = {"field_in": np.float64, "field_out": np.float64, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] - backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) + backends = ALL_BACKENDS symbols = { "field_in": gt_testing.field( in_range=(0.1, 10), axes="IJK", boundary=[(0, 0), (0, 0), (1, 0)] diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 123384a098..2c4102d5af 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -187,6 +187,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.GTIR_EMBEDDED: [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index a0e72ede8d..333a2dae28 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -46,7 +46,7 @@ def __gt_allocator__( @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, - # next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # FIXME[#1582](havogt): enable once all ingredients for GTIR are available + next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index e3e919e52e..f26424bf0e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -30,8 +30,10 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, itir.FencilDefinition) - assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.FencilDefinition) + assert isinstance(testee.itir, (itir.FencilDefinition, itir.Program)) + assert isinstance( + testee.with_backend(cartesian_case.backend).itir, (itir.FencilDefinition, itir.Program) + ) def test_frozen(cartesian_case): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py new file mode 100644 index 0000000000..56d52c75ae --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -0,0 +1,124 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + +import gt4py.next as gtx +from gt4py.next.embedded import context as embedded_context +from gt4py.next.iterator import embedded, runtime +from gt4py.next.iterator.builtins import ( + as_fieldop, + deref, + if_, + make_const_list, + map_, + neighbors, + plus, +) + + +E = gtx.Dimension("E") +V = gtx.Dimension("V") +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) +E2V = gtx.FieldOffset("E2V", source=V, target=(E, E2VDim)) + + +# 0 --0-- 1 --1-- 2 +e2v_arr = np.array([[0, 1], [1, 2]]) +e2v_conn = gtx.NeighborTableOffsetProvider( + table=e2v_arr, + origin_axis=E, + neighbor_axis=V, + max_neighbors=2, + has_skip_values=False, +) + + +def test_write_neighbors(): + def testee(inp): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda it: neighbors(E2V, it), domain)(inp) + + inp = gtx.as_field([V], np.arange(3)) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp) + + ref = e2v_arr + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_const_list(): + def testee(): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda: make_const_list(42.0), domain)() + + with embedded_context.new_context(offset_provider={}) as ctx: + result = ctx.run(testee) + + ref = np.asarray([[42.0], [42.0]]) + + assert result.domain.dims[0] == E + assert result.domain.dims[1] == embedded._CONST_DIM # this is implementation detail + assert result.shape[1] == 1 # this is implementation detail + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_neighbors_and_const_list(): + def testee(inp): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + as_fieldop(lambda: make_const_list(42.0), domain)(), + ) + + inp = gtx.as_field([V], np.arange(3)) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp) + + ref = e2v_arr + 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_conditional_neighbors_and_const_list(): + def testee(inp, mask): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda m, x, y: map_(if_)(deref(m), deref(x), deref(y)), domain)( + as_fieldop(lambda it: make_const_list(deref(it)), domain)(mask), + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + as_fieldop(lambda it: make_const_list(deref(it)), domain)(42.0), + ) + + inp = gtx.as_field([V], np.arange(3)) + mask_field = gtx.as_field([E], np.array([True, False])) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp, mask_field) + + ref = np.empty_like(e2v_arr, dtype=float) + ref[0, :] = e2v_arr[0, :] + ref[1, :] = 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_const_list_and_const_list(): + def testee(): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + as_fieldop(lambda: make_const_list(1.0), domain)(), + as_fieldop(lambda: make_const_list(42.0), domain)(), + ) + + with embedded_context.new_context(offset_provider={}) as ctx: + result = ctx.run(testee) + + ref = np.asarray([[43.0], [43.0]]) + + assert result.domain.dims[0] == E + assert result.domain.dims[1] == embedded._CONST_DIM # this is implementation detail + assert result.shape[1] == 1 # this is implementation detail + np.testing.assert_array_equal(result.asnumpy(), ref)