Skip to content

Commit

Permalink
Realized that I can not use SDFG.shared_transient() for detection i…
Browse files Browse the repository at this point in the history
…f data can be removed.

This is because the function is much less strict.
  • Loading branch information
philip-paul-mueller committed Nov 1, 2024
1 parent 3453c6c commit 90731af
Showing 1 changed file with 74 additions and 9 deletions.
83 changes: 74 additions & 9 deletions dace/transformation/dataflow/map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,14 +1150,17 @@ def is_shared_data(
data: nodes.AccessNode,
sdfg: dace.SDFG,
) -> bool:
"""Tests if `data` is interstate data, an can not be removed.
"""Tests if `data` is shared data, an can not be removed.
Interstate data is used to transmit data between multiple state or by
extension within the state. Thus it must be classified as a shared output.
This function will go through the SDFG to and collect the names of all data
container that should be classified as shared. Note that this is an over
approximation as it does not take the location into account, i.e. "is no longer
used".
Interstate data is used to transmit data, this includes:
- The data is referred in multiple states.
- The data is referred to multiple times in the same state, either the state
has multiple access nodes for that data or an access node has an
out degree larger than one.
- The data is read inside interstate edges.
This definition is stricter than the one employed by `SDFG.shared_transients()`,
as it also includes usage within a state.
Args:
transient: The transient that should be checked.
Expand All @@ -1166,7 +1169,7 @@ def is_shared_data(
Note:
The function computes the this set once for every SDFG and then caches it.
There is no mechanism to detect if the cache must be evicted. However,
as long as no additional data is added, there is no problem.
as long as no additional data is added to the SDFG, there is no problem.
"""
if sdfg not in self._shared_data:
self._compute_shared_data_in(sdfg)
Expand All @@ -1184,7 +1187,69 @@ def _compute_shared_data_in(
Args:
sdfg: The SDFG for which the set of shared data should be computed.
"""
self._shared_data[sdfg] = set(sdfg.shared_transients())
# Shared data of this SDFG.
shared_data: Set[str] = set()

# All global data can not be removed, so it must always be shared.
for data_name, data_desc in sdfg.arrays.items():
if not data_desc.transient:
shared_data.add(data_name)
elif isinstance(data_desc, dace.data.Scalar):
shared_data.add(data_name)

# We go through all states and classify the nodes/data:
# - Data is referred to in different states.
# - The access node is a view (both have to survive).
# - Transient sink or source node.
# - The access node has output degree larger than 1 (input degrees larger
# than one, will always be partitioned as shared anyway).
prevously_seen_data: Set[str] = set()
for state in sdfg.nodes():
for access_node in state.data_nodes():

if access_node.data in shared_data:
# The data was already classified to be shared data
pass

elif access_node.data in prevously_seen_data:
# We have seen this data before, either in this state or in
# a previous one, but we did not classifies it as shared back then
shared_data.add(access_node.data)

if state.in_degree(access_node) == 0:
# (Transient) sink nodes are used in other states, or simplify
# will get rid of them.
shared_data.add(access_node.data)

elif state.out_degree(access_node) != 1: # state.out_degree() == 0 or state.out_degree() > 1
# The access node is either a source node (it is shared in another
# state) or the node has a degree larger than one, so it is used
# in this state somewhere else.
shared_data.add(access_node.data)

elif self.is_view(node=access_node, sdfg=sdfg):
# To ensure that the write to the view happens, both have to be shared.
viewed_data: str = self.track_view(view=access_node, state=state, sdfg=sdfg).data
shared_data.update([access_node.data, viewed_data])
prevously_seen_data.update([access_node.data, viewed_data])

else:
# The node was not classified as shared data, so we record that
# we saw it. Note that a node that was immediately classified
# as shared node will never be added to this set, but a data
# that was found twice will be inside this list.
prevously_seen_data.add(access_node.data)

# Now we collect all symbols that are read in interstate edges.
# Because, they might refer to data inside states and must be kept alive.
interstate_read_symbols: Set[str] = set()
for edge in sdfg.edges():
interstate_read_symbols.update(edge.data.read_symbols())
data_read_in_interstate_edges = interstate_read_symbols.intersection(prevously_seen_data)

# Compute the final set of shared data and update the internal cache.
shared_data.update(data_read_in_interstate_edges)
self._shared_data[sdfg] = shared_data


def _compute_multi_write_data(
Expand Down

0 comments on commit 90731af

Please sign in to comment.