Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consider data descriptor offset when propagating memlets #1461

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions dace/sdfg/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,15 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node):
if internal_memlet is None:
continue
try:
iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True)
ext_desc = parent_sdfg.arrays[iedge.data.data]
int_desc = sdfg.arrays[iedge.dst_conn]
iedge.data = unsqueeze_memlet(
internal_memlet,
iedge.data,
True,
internal_offset=int_desc.offset,
external_offset=ext_desc.offset
)
# If no appropriate memlet found, use array dimension
for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[iedge.data.data].shape)):
if rng[1] + 1 == s:
Expand All @@ -1121,7 +1129,15 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node):
if internal_memlet is None:
continue
try:
oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True)
ext_desc = parent_sdfg.arrays[oedge.data.data]
int_desc = sdfg.arrays[oedge.src_conn]
oedge.data = unsqueeze_memlet(
internal_memlet,
oedge.data,
True,
internal_offset=int_desc.offset,
external_offset=ext_desc.offset
)
# If no appropriate memlet found, use array dimension
for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[oedge.data.data].shape)):
if rng[1] + 1 == s:
Expand Down
81 changes: 64 additions & 17 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,14 +507,24 @@ def apply(self, state: SDFGState, sdfg: SDFG):
if (edge not in modified_edges and edge.data.data == node.data):
for e in state.memlet_tree(edge):
if e._data.get_dst_subset(e, state):
new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_dst_subset=True)
offset = sdfg.arrays[e.data.data].offset
new_memlet = helpers.unsqueeze_memlet(e.data,
outer_edge.data,
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)
e._data.dst_subset = new_memlet.subset
# NOTE: Node is source
for edge in state.out_edges(node):
if (edge not in modified_edges and edge.data.data == node.data):
for e in state.memlet_tree(edge):
if e._data.get_src_subset(e, state):
new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_src_subset=True)
offset = sdfg.arrays[e.data.data].offset
new_memlet = helpers.unsqueeze_memlet(e.data,
outer_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
e._data.src_subset = new_memlet.subset

# If source/sink node is not connected to a source/destination access
Expand Down Expand Up @@ -623,10 +633,17 @@ def _modify_access_to_access(self,
state.out_edges_by_connector(nsdfg_node, inner_data))
# Create memlet by unsqueezing both w.r.t. src and
# dst subsets
in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True)
offset = state.parent.arrays[top_edge.data.data].offset
in_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
out_memlet = helpers.unsqueeze_memlet(inner_edge.data,
matching_edge.data,
use_dst_subset=True)
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)
new_memlet = in_memlet
new_memlet.other_subset = out_memlet.subset

Expand All @@ -649,10 +666,18 @@ def _modify_access_to_access(self,
state.out_edges_by_connector(nsdfg_node, inner_data))
# Create memlet by unsqueezing both w.r.t. src and
# dst subsets
in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True)
offset = state.parent.arrays[top_edge.data.data].offset
in_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
out_memlet = helpers.unsqueeze_memlet(inner_edge.data,
matching_edge.data,
use_dst_subset=True)
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)

new_memlet = in_memlet
new_memlet.other_subset = out_memlet.subset

Expand Down Expand Up @@ -687,7 +712,11 @@ def _modify_memlet_path(
if inner_edge in edges_to_ignore:
new_memlet = inner_edge.data
else:
new_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data)
offset = state.parent.arrays[top_edge.data.data].offset
new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
internal_offset=offset,
external_offset=offset)
if inputs:
if inner_edge.dst in inner_to_outer:
dst = inner_to_outer[inner_edge.dst]
Expand All @@ -706,15 +735,19 @@ def _modify_memlet_path(
mtree = state.memlet_tree(new_edge)

# Modify all memlets going forward/backward
def traverse(mtree_node):
def traverse(mtree_node, state, nstate):
result.add(mtree_node.edge)
mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data, top_edge.data)
offset = state.parent.arrays[top_edge.data.data].offset
mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data,
top_edge.data,
internal_offset=offset,
external_offset=offset)
for child in mtree_node.children:
traverse(child)
traverse(child, state, nstate)

result.add(new_edge)
for child in mtree.children:
traverse(child)
traverse(child, state, nstate)

return result

Expand Down Expand Up @@ -1032,7 +1065,8 @@ def _check_cand(candidates, outer_edges):

# If there are any symbols here that are not defined
# in "defined_symbols"
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - set(nsdfg.symbol_mapping.keys()))
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) -
set(nsdfg.symbol_mapping.keys()))
if missing_symbols:
ignore.add(cname)
continue
Expand All @@ -1041,10 +1075,13 @@ def _check_cand(candidates, outer_edges):
_check_cand(out_candidates, state.out_edges_by_connector)

# Return result, filtering out the states
return ({k: (dc(v), ind)
for k, (v, _, ind) in in_candidates.items()
if k not in ignore}, {k: (dc(v), ind)
for k, (v, _, ind) in out_candidates.items() if k not in ignore})
return ({
k: (dc(v), ind)
for k, (v, _, ind) in in_candidates.items() if k not in ignore
}, {
k: (dc(v), ind)
for k, (v, _, ind) in out_candidates.items() if k not in ignore
})

def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False):
nsdfg = self.nsdfg
Expand All @@ -1067,7 +1104,17 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]],
outer_edge = next(iter(outer_edges(nsdfg_node, aname)))
except StopIteration:
continue
new_memlet = helpers.unsqueeze_memlet(refine, outer_edge.data)

if isinstance(outer_edge.dst, nodes.NestedSDFG):
conn = outer_edge.dst_conn
else:
conn = outer_edge.src_conn
int_desc = nsdfg.arrays[conn]
ext_desc = sdfg.arrays[outer_edge.data.data]
new_memlet = helpers.unsqueeze_memlet(refine,
outer_edge.data,
internal_offset=int_desc.offset,
external_offset=ext_desc.offset)
outer_edge.data.subset = subsets.Range([
ns if i in indices else os
for i, (os, ns) in enumerate(zip(outer_edge.data.subset, new_memlet.subset))
Expand Down
39 changes: 39 additions & 0 deletions tests/memlet_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,48 @@ def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]):
raise RuntimeError('Expected subset of outer out memlet to be [0:M, 0:N], found ' +
str(outer_out.subset))

def test_memlet_propagation_with_offsets():
    """Check that memlet propagation yields a valid SDFG when arrays carry offsets.

    A small Fortran program is parsed into an SDFG, simplified, and its loops
    are turned into maps via ``LoopToMap``. Memlets are then propagated over
    the whole SDFG and the result is validated; any exception raised along the
    way fails the test.
    """
    fortran_source = """
PROGRAM foo
IMPLICIT NONE
REAL INP1(NBLOCKS, KLEV)
INTEGER, PARAMETER :: KLEV = 137
INTEGER, PARAMETER :: NBLOCKS = 8


CALL foo_test_function(NBLOCKS, KLEV, INP1)

END PROGRAM

SUBROUTINE foo_test_function(NBLOCKS, KLEV, INP1)
INTEGER, PARAMETER :: KLEV = 137
INTEGER, PARAMETER :: NBLOCKS = 1
REAL INP1(NBLOCKS, KLEV)

DO JN=1,NBLOCKS
DO JK=1,KLEV
INP1(JN, JK) = (JN-1) * KLEV + (JK-1)
ENDDO
ENDDO
END SUBROUTINE foo_test_function
"""

    # Imported inside the test body rather than at module level
    # (presumably so the other tests do not require the Fortran frontend -- confirm).
    from dace.frontend.fortran import fortran_parser
    from dace.transformation.interstate import LoopToMap

    sdfg_name = "test_loop_map_parallel"
    sdfg = fortran_parser.create_sdfg_from_string(fortran_source, sdfg_name)

    # Reshape the graph into map(NestedSDFG(map))
    sdfg.simplify()
    loop_transformations = [LoopToMap]
    sdfg.apply_transformations_repeated(loop_transformations)

    # The array offsets (-1, -1) have to make it through the NestedSDFG
    # boundary correctly during propagation
    propagate_memlets_sdfg(sdfg)
    sdfg.validate()


if __name__ == '__main__':
    # Run every test in this module, in declaration order, when executed as a script.
    for run_test in (
            test_conditional,
            test_conditional_nested,
            test_runtime_conditional,
            test_nsdfg_memlet_propagation_with_one_sparse_dimension,
            test_memlet_propagation_with_offsets,
    ):
        run_test()
Loading