Skip to content

Commit

Permalink
feat[next]: Test for local dimension in output (#1392)
Browse files Browse the repository at this point in the history
Currently only supported in field view embedded.
  • Loading branch information
havogt authored Dec 19, 2023
1 parent af33e21 commit b21dd56
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ markers = [
'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
'uses_tuple_args: tests that require backend support for tuple arguments',
'uses_tuple_returns: tests that require backend support for tuple results',
Expand Down
3 changes: 3 additions & 0 deletions tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions"
USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator"
USES_SPARSE_FIELDS = "uses_sparse_fields"
USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output"
USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields"
USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset"
USES_TUPLE_ARGS = "uses_tuple_args"
Expand All @@ -119,6 +120,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
(USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
]
DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
(USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
Expand Down Expand Up @@ -159,4 +161,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
ProgramFormatterId.GTFN_CPP_FORMATTER: [
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
],
ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)],
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,22 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32
out=cases.allocate(unstructured_case, testee, cases.RETURN)(),
ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1),
)


@pytest.mark.uses_sparse_fields_as_output
def test_write_local_field(unstructured_case):
@gtx.field_operator
def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]:
return inp(V2E)

out = unstructured_case.as_field(
[Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table)
)
inp = cases.allocate(unstructured_case, testee, "inp")()
cases.verify(
unstructured_case,
testee,
inp,
out=out,
ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table],
)

0 comments on commit b21dd56

Please sign in to comment.