Skip to content

Commit

Permalink
Add fallback to legacy state machines to compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Sep 19, 2023
1 parent 4744d08 commit 41a0abf
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion dace/sdfg/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _get_storage_from_parent(data_name: str, sdfg: SDFG) -> dtypes.StorageType:
"""
nsdfg_node = sdfg.parent_nsdfg_node
parent_state = sdfg.parent
parent_sdfg = parent_state.parent
parent_sdfg = parent_state.sdfg

# Find data descriptor in parent SDFG
if data_name in nsdfg_node.in_connectors:
Expand Down
5 changes: 2 additions & 3 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,8 @@ def label(self):
def __label__(self, sdfg, state):
return self.data

def desc(self, sdfg):
from dace.sdfg import SDFGState, ScopeSubgraphView
if isinstance(sdfg, (SDFGState, ScopeSubgraphView)):
def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']):
if not isinstance(sdfg, dace.sdfg.SDFG):
sdfg = sdfg.sdfg
return sdfg.arrays[self.data]

Expand Down
5 changes: 5 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':

# Importing these outside creates an import loop
from dace.codegen import codegen, compiler
from dace.sdfg import utils as sdutils

# Compute build folder path before running codegen
build_folder = self.build_folder
Expand All @@ -2029,6 +2030,10 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
############################
# DaCe Compilation Process #

# Convert any scope blocks to old-school state machines for now.
# TODO: Adapt codegen to deal wiht scope blocks instead.
sdutils.inline_loop_blocks(self)

if self._regenerate_code or not os.path.isdir(build_folder):
# Clone SDFG as the other modules may modify its contents
sdfg = copy.deepcopy(self)
Expand Down
10 changes: 5 additions & 5 deletions dace/transformation/passes/fusion_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,19 @@ def modifies(self) -> ppl.Modifies:

def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]:
modified = 0
for node, state in sdfg.all_nodes_recursive():
for node, parent in sdfg.all_nodes_recursive():
if not isinstance(node, nodes.NestedSDFG):
continue
was_modified = False
if node.sdfg.parent_nsdfg_node is not node:
was_modified = True
node.sdfg.parent_nsdfg_node = node
if node.sdfg.parent is not state:
if node.sdfg.parent is not parent:
was_modified = True
node.sdfg.parent = state
if node.sdfg.parent_sdfg is not state.parent:
node.sdfg.parent = parent
if node.sdfg.parent_sdfg is not parent.sdfg:
was_modified = True
node.sdfg.parent_sdfg = state.parent
node.sdfg.parent_sdfg = parent.sdfg

if was_modified:
modified += 1
Expand Down

0 comments on commit 41a0abf

Please sign in to comment.