Skip to content

Commit

Permalink
Decrease code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Nov 28, 2024
1 parent e3ffb3b commit dbf6caf
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions tests/transformations/gpu_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,41 +118,22 @@ def write_subset_dynamic(A: dace.int32[20, 20], x: dace.int32[20], y: dace.int32

assert np.array_equal(ref, val)


@pytest.mark.parametrize("transient", [False, True])
def test_free_tasklet_and_array(transient):
@pytest.mark.parametrize(["transient", "scalar"],
[[False, False], [False, True],
[True, False], [True, True]])
def test_free_tasklet(transient, scalar):
sdfg = dace.SDFG("assign")

state = sdfg.add_state("main")
arr_name, arr = sdfg.add_array("A", (4,), dace.float32, transient=transient)
an = state.add_access(arr_name)

t = state.add_tasklet("assign", {}, {"_out"}, "_out = 2.0")
state.add_edge(t, "_out", an, None, dace.memlet.Memlet("A[0]"))

sdfg.validate()

sdfg.apply_gpu_transformations(
validate = True,
validate_all = True,
permissive = True,
sequential_innermaps=True,
register_transients=False,
simplify=False
)
if scalar:
arr_name, arr = sdfg.add_scalar("A", dace.float32, transient=transient)
else:
arr_name, arr = sdfg.add_array("A", (4,), dace.float32, transient=transient)

sdfg.validate()

@pytest.mark.parametrize("transient", [False, True])
def test_free_tasklet_and_scalar(transient):
sdfg = dace.SDFG("assign")

state = sdfg.add_state("main")
arr_name, arr = sdfg.add_scalar("A", dace.float32, transient=transient)
an = state.add_access(arr_name)

t = state.add_tasklet("assign", {}, {"_out"}, "_out = 2.0")
state.add_edge(t, "_out", an, None, dace.memlet.Memlet("A"))
state.add_edge(t, "_out", an, None, dace.memlet.Memlet("A" if scalar else "A[0]"))

sdfg.validate()

Expand All @@ -173,6 +154,6 @@ def test_free_tasklet_and_scalar(transient):
test_write_subset()
test_write_full()
test_write_subset_dynamic()
for transient in [False, True]:
test_free_tasklet_and_array(transient)
test_free_tasklet_and_array(transient)
for scalar in [False, True]:
for transient in [False, True]:
test_free_tasklet(transient, scalar)

0 comments on commit dbf6caf

Please sign in to comment.