Skip to content

Commit

Permalink
Address more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Jun 18, 2024
1 parent 0238393 commit 31818f8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 25 deletions.
4 changes: 3 additions & 1 deletion dace/transformation/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,9 @@ def _make_function_blocksafe(cls: ppl.Pass, function_name: str, get_sdfg_arg: Ca
if hasattr(cls, function_name):
vanilla_method = getattr(cls, function_name)
def blocksafe_wrapper(tgt, *args, **kwargs):
if kwargs and 'sdfg' in kwargs:
if isinstance(tgt, SDFG):
sdfg = tgt
elif kwargs and 'sdfg' in kwargs:
sdfg = kwargs['sdfg']
else:
sdfg = get_sdfg_arg(tgt, *args)
Expand Down
27 changes: 3 additions & 24 deletions tests/transformations/loop_to_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,39 +667,18 @@ def find_loop(sdfg: dace.SDFG, itervar: str) -> Tuple[dace.SDFGState, dace.SDFGS

sdfg0 = copy.deepcopy(sdfg)
i_guard, i_begin, i_exit = find_loop(sdfg0, 'i')
l2m1_subgraph = {
DetectLoop.loop_guard: i_guard.block_id,
DetectLoop.loop_begin: i_begin.block_id,
DetectLoop.exit_state: i_exit.block_id,
}
xf1 = LoopToMap()
xf1.setup_match(sdfg0, sdfg0.cfg_id, -1, l2m1_subgraph, 0)
xf1.apply(sdfg0, sdfg0)
LoopToMap.apply_to(sdfg0, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit)
nsdfg = next((sd for sd in sdfg0.all_sdfgs_recursive() if sd.parent is not None))
j_guard, j_begin, j_exit = find_loop(nsdfg, 'j')
l2m2_subgraph = {
DetectLoop.loop_guard: j_guard.block_id,
DetectLoop.loop_begin: j_begin.block_id,
DetectLoop.exit_state: j_exit.block_id,
}
xf2 = LoopToMap()
xf2.setup_match(nsdfg, nsdfg.cfg_id, -1, l2m2_subgraph, 0)
xf2.apply(nsdfg, nsdfg)
LoopToMap.apply_to(nsdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit)

val = np.arange(1000, dtype=np.int32).reshape(10, 10, 10).copy()
sdfg(A=val, l=5)

assert np.allclose(ref, val)

j_guard, j_begin, j_exit = find_loop(sdfg, 'j')
l2m3_subgraph = {
DetectLoop.loop_guard: j_guard.block_id,
DetectLoop.loop_begin: j_begin.block_id,
DetectLoop.exit_state: j_exit.block_id,
}
xf3 = LoopToMap()
xf3.setup_match(sdfg, sdfg.cfg_id, -1, l2m3_subgraph, 0)
xf3.apply(sdfg, sdfg)
LoopToMap.apply_to(sdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit)
# NOTE: The following fails to apply because of subset A[0:i+1], which is overapproximated.
# i_guard, i_begin, i_exit = find_loop(sdfg, 'i')
# LoopToMap.apply_to(sdfg, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit)
Expand Down

0 comments on commit 31818f8

Please sign in to comment.