Skip to content

Commit

Permalink
Added more control to when to merge.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Aug 22, 2024
1 parent 946cd82 commit 093ff96
Showing 1 changed file with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

"""Implements the parallel map fusing transformation."""

from typing import Any, Union
from typing import Any, Optional, Union

import dace
from dace import properties, transformation
Expand All @@ -29,20 +29,29 @@ class ParallelMapFusion(map_fusion_helper.MapFusionHelper):
and are in the same scope.
Args:
only_if_common_ancestor: Only perform fusion if both Maps share at least one
node as direct ancestor. This will increase the locality of the merge.
only_inner_maps: Only match Maps that are internal, i.e. inside another Map.
only_toplevel_maps: Only consider Maps that are at the top.
Todo:
Add options to restrict the surrounding further, such as common input.
"""

map_entry1 = transformation.transformation.PatternNode(nodes.MapEntry)
map_entry2 = transformation.transformation.PatternNode(nodes.MapEntry)

only_if_common_ancestor = properties.Property(
dtype=bool,
default=False,
allow_none=False,
desc="Only perform fusing if the Maps share a node as parent.",
)

def __init__(
self,
only_if_common_ancestor: Optional[bool] = None,
**kwargs: Any,
) -> None:
if only_if_common_ancestor is not None:
self.only_if_common_ancestor = only_if_common_ancestor
super().__init__(**kwargs)

@classmethod
Expand Down Expand Up @@ -80,6 +89,13 @@ def can_be_applied(
if not util.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2):
return False

# Test if they have they share a node as direct ancestor.
if self.only_if_common_ancestor:
# This assumes that there is only one access node per data container in the state.
ancestors_1: set[nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)}
if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)):
return False

return True

def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None:
Expand Down

0 comments on commit 093ff96

Please sign in to comment.