Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: DaCe support for floordiv #1337

Merged
merged 5 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def itir_type_as_dace_type(type_: next_typing.Type):
"minus": "({} - {})",
"multiplies": "({} * {})",
"divides": "({} / {})",
"floordiv": "({} // {})",
"eq": "({} == {})",
"not_eq": "({} != {})",
"less": "({} < {})",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,6 @@ def program_domain(a: cases.IField, out: cases.IField):
def test_domain_input_bounds(cartesian_case):
if cartesian_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]:
pytest.xfail("FloorDiv not fully supported in gtfn.")
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: type inference failure")

lower_i = 1
upper_i = 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def make_builtin_field_operator(builtin_name: str):

@pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data())
def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inputs):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Bug in type inference with math builtins, breaks dace backend.")
if builtin_name == "gamma":
# numpy has no gamma function
ref_impl: Callable = np.vectorize(math.gamma)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ def arithmetic(inp1: cases.IFloatField, inp2: cases.IFloatField) -> gtx.Field[[I


def test_power(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Bug in type inference with math builtins, breaks dace backend.")

@gtx.field_operator
def pow(inp1: cases.IField) -> cases.IField:
return inp1**2
Expand All @@ -74,7 +71,6 @@ def test_floordiv(cartesian_case):
if cartesian_case.backend in [
gtfn_cpu.run_gtfn,
gtfn_cpu.run_gtfn_imperative,
dace_iterator.run_dace_iterator,
]:
pytest.xfail(
"FloorDiv not yet supported."
Expand Down Expand Up @@ -201,9 +197,6 @@ def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField:


def test_basic_trig(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Bug in type inference with math builtins, breaks dace backend.")

@gtx.field_operator
def basic_trig_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFloatField:
return sin(cos(inp1)) - sinh(cosh(inp2)) + tan(inp1) - tanh(inp2)
Expand All @@ -219,9 +212,6 @@ def basic_trig_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> case


def test_exp_log(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Bug in type inference with math builtins, breaks dace backend.")

@gtx.field_operator
def exp_log_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFloatField:
return log(inp1) - exp(inp2)
Expand All @@ -232,9 +222,6 @@ def exp_log_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.I


def test_roots(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Bug in type inference with math builtins, breaks dace backend.")

@gtx.field_operator
def roots_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFloatField:
return sqrt(inp1) - cbrt(inp2)
Expand All @@ -245,9 +232,6 @@ def roots_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFl


def test_is_values(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Bug in type inference with math builtins, breaks dace backend.")

@gtx.field_operator
def is_isinf_fieldop(inp1: cases.IFloatField) -> cases.IBoolField:
return isinf(inp1)
Expand All @@ -274,9 +258,6 @@ def is_isfinite_fieldop(inp1: cases.IFloatField) -> cases.IBoolField:


def test_rounding_funs(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Bug in type inference with math builtins, breaks dace backend.")

@gtx.field_operator
def rounding_funs_fieldop(
inp1: cases.IFloatField, inp2: cases.IFloatField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,8 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices):
return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices)))


def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(
program_processor_no_dace_exec, lift_mode
):
program_processor, validate = program_processor_no_dace_exec
def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor, lift_mode):
program_processor, validate = program_processor
inp = vertex_index_field()
out = gtx.np_as_located_field(Cell)(np.zeros([9], dtype=inp.dtype))
ref = np.asarray(list(v2e_arr[c[0]][0] for c in c2e_arr))
Expand Down