Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add if extraction transformation #1641

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions dace/transformation/interstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from .trivial_loop_elimination import TrivialLoopElimination
from .multistate_inline import InlineMultistateSDFG
from .move_assignment_outside_if import MoveAssignmentOutsideIf
from .if_extraction import IfExtraction
160 changes: 160 additions & 0 deletions dace/transformation/interstate/if_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" If extraction transformation """
from dace import SDFG, data, InterstateEdge
from dace.properties import make_properties
from dace.sdfg import utils
from dace.sdfg.nodes import NestedSDFG
from dace.sdfg.state import SDFGState
from dace.symbolic import pystr_to_symbolic
from dace.transformation import transformation


def eliminate_branch(sdfg: SDFG, initial_state: SDFGState):
"""
Eliminates all nodes that are reachable _only_ from `initial_state`.

Assumptions:
- The topmost level of each branch consists of `SDFGState` states connected by interstate edges.

Example:
- If we start from `state_1` for the following graph, only `state_1` will be removed.
initial_state
/ \\
state_1 state_2
\\ /
state_3
|
terminal_state
- If we start from `state_1` for the following graph, `state_1` and `state_3` will be removed. But after that,
starting from `state_2` will remove the other four intermediate state too.
initial_state
/ \\
state_1 state_2
| |
state_3 state_5
\\ /
state_5
|
state_6
|
terminal_state
"""
assert len(sdfg.in_edges(initial_state)) == 1
states_to_remove = {initial_state}
while states_to_remove:
assert all(isinstance(st, SDFGState) for st in states_to_remove)
new_states_to_remove = {e.dst for s in states_to_remove for e in sdfg.out_edges(s)
if len(sdfg.in_edges(e.dst)) == 1}
for s in states_to_remove:
sdfg.remove_node(s)
states_to_remove = new_states_to_remove


@make_properties
class IfExtraction(transformation.MultiStateTransformation):
"""
Detects an If statement as the root of a nested SDFG, and if so, extracts it by computing it in the outer SDFG and
replicating the state containing the nested SDFG.
"""

root_state = transformation.PatternNode(SDFGState)

@classmethod
def expressions(cls):
return [utils.node_path_graph(cls.root_state)]

def can_be_applied(self, graph, expr_index: int, sdfg, permissive=False):
if not sdfg.parent:
# Must be a nested SDFG.
return False

in_edges, out_edges = graph.in_edges(self.root_state), graph.out_edges(self.root_state)
if not (len(in_edges) == 0 and len(out_edges) == 2):
# Such an If state must have an inverted V shape.
return False

# Collect outer symbols used in the interstate edges going out of the If guard.
if_symbols = set(str(nested) for e in out_edges for s in e.data.free_symbols
for nested in pystr_to_symbolic(sdfg.parent_nsdfg_node.symbol_mapping[s]).free_symbols)

# Collect symbols available to state containing the nested SDFG.
parent_sdfg = sdfg.parent.sdfg
available_symbols = parent_sdfg.symbols.keys() | parent_sdfg.arglist().keys()
for desc in parent_sdfg.arrays.values():
available_symbols |= {str(s) for s in desc.free_symbols}
for e in sdfg.predecessor_state_transitions(sdfg.start_state):
available_symbols |= e.data.new_symbols(sdfg, available_symbols).keys()

if not if_symbols.issubset(available_symbols):
# The symbols used on the branch must be computable in the outer scope.
return False

_, wset = sdfg.parent.read_and_write_sets()
if if_symbols.intersection(wset):
# The symbols used on the branch must not be written in the parent state of the nested SDFG.
return False

return True

def apply(self, graph: SDFGState, sdfg: SDFG):
if_root_state: SDFGState = self.root_state
if_branch: SDFGState = sdfg.parent
outer_sdfg: SDFG = if_branch.sdfg
if_nested_sdfg_node: NestedSDFG = sdfg.parent_nsdfg_node

if_edge, else_edge = sdfg.out_edges(if_root_state)

# Create new state to perform the If, and have it replace the state containing the nested SDFG.
new_state = outer_sdfg.add_state()
utils.change_edge_dest(outer_sdfg, if_branch, new_state)

# Take the old state as the If branch, and create a copy to act as the else branch.
else_branch = SDFGState.from_json(if_branch.to_json(), context={'sdfg': outer_sdfg})
else_branch.label = data.find_new_name(else_branch.label, outer_sdfg._labels)
outer_sdfg.add_node(else_branch)

# Find the corresponding elements in the new state.
else_nested_sdfg_node = [n for n in else_branch.nodes() if n.label == if_nested_sdfg_node.label]
assert len(else_nested_sdfg_node) == 1
else_nested_sdfg_node = else_nested_sdfg_node[0]
else_root_state = [s for s in else_nested_sdfg_node.sdfg.states() if s.label == if_root_state.label]
assert len(else_root_state) == 1
else_root_state = else_root_state[0]

# Delete the else subgraph in the If state.
eliminate_branch(sdfg, sdfg.out_edges(if_root_state)[1].dst)
# Optimization: Delete new base state if useless.
new_base_state = sdfg.out_edges(if_root_state)[0].dst
if not new_base_state.nodes() and len(sdfg.out_edges(new_base_state)) == 1:
out_edge = sdfg.out_edges(new_base_state)[0]
if len(out_edge.data.assignments) == 0 and out_edge.data.is_unconditional():
sdfg.remove_node(new_base_state)
sdfg.remove_node(if_root_state)

# Do the opposite for Else state.
else_sdfg = else_nested_sdfg_node.sdfg
eliminate_branch(else_sdfg, else_sdfg.out_edges(else_root_state)[0].dst)
new_base_state = else_sdfg.out_edges(else_root_state)[0].dst
if len(new_base_state.nodes()) == 0 and len(else_sdfg.out_edges(new_base_state)) == 1:
out_edge = else_sdfg.out_edges(new_base_state)[0]
if len(out_edge.data.assignments) == 0 and out_edge.data.is_unconditional():
else_sdfg.remove_node(new_base_state)
else_sdfg.remove_node(else_root_state)

# Connect the If and Else state.
if_edge.data.replace_dict(if_nested_sdfg_node.symbol_mapping)
else_edge.data.replace_dict(if_nested_sdfg_node.symbol_mapping)

# Translate interstate edge assignemnts to symbol mappings.
if_nested_sdfg_node.symbol_mapping.update(if_edge.data.assignments)
else_nested_sdfg_node.symbol_mapping.update(else_edge.data.assignments)

# Connect everything.
outer_sdfg.add_edge(new_state, if_branch, InterstateEdge(if_edge.data.condition))
outer_sdfg.add_edge(new_state, else_branch, InterstateEdge(else_edge.data.condition))

# Make sure the SDFG is valid.
if not outer_sdfg.out_edges(if_branch):
outer_sdfg.add_state_after(if_branch)
for e in outer_sdfg.out_edges(if_branch):
outer_sdfg.add_edge(else_branch, e.dst, InterstateEdge(e.data.condition, e.data.assignments))
225 changes: 225 additions & 0 deletions tests/transformations/if_extraction_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import os
from copy import deepcopy

import numpy as np

import dace
from dace import SDFG, InterstateEdge, Memlet
from dace.transformation.interstate import IfExtraction


def make_branched_sdfg_that_does_not_depend_on_loop_var():
"""
Construct a simple SDFG that does not depend on symbols defined or updated in the outer state, e.g., loop variables.
"""
# First prepare the map-body.
subg = SDFG('body')
subg.add_array('tmp', (1,), dace.float32)
subg.add_symbol('outval', dace.float32)
ifh = subg.add_state('if_head')
if1 = subg.add_state('if_b1')
if2 = subg.add_state('if_b2')
ift = subg.add_state('if_tail')
subg.add_edge(ifh, if1, InterstateEdge(condition='(flag)', assignments={'outval': 1}))
subg.add_edge(ifh, if2, InterstateEdge(condition='(not flag)', assignments={'outval': 2}))
subg.add_edge(if1, ift, InterstateEdge())
subg.add_edge(if2, ift, InterstateEdge())
t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval')
tmp = ift.add_access('tmp')
ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]'))
subg.fill_scope_connectors()

# Then prepare the parent graph.
g = SDFG('prog')
g.add_array('A', (10,), dace.float32)
g.add_symbol('flag', dace.bool)
st0 = g.add_state('outer', is_start_block=True)
en, ex = st0.add_map('map', {'i': '0:10'})
body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'})
A = st0.add_access('A')
st0.add_memlet_path(en, body, memlet=Memlet())
st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]'))
st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]'))
g.fill_scope_connectors()

return g


def make_branched_sdfg_that_has_intermediate_branchlike_structure():
"""
Construct an SDFG that has this structure:
initial_state
/ \\
state_1 state_2
| |
state_3 state_5
\\ /
state_5
/ \
state_6 state_7
\\ /
terminal_state
"""
# First prepare the map-body.
subg = SDFG('body')
subg.add_array('tmp', (1,), dace.float32)
subg.add_symbol('outval', dace.float32)
ifh = subg.add_state('if_head')
if1 = subg.add_state('state_1')
if3 = subg.add_state('state_2')
if2 = subg.add_state('state_3')
if4 = subg.add_state('state_4')
if5 = subg.add_state('state_5')
if6 = subg.add_state('state_6')
if7 = subg.add_state('state_7')
ift = subg.add_state('if_tail')
subg.add_edge(ifh, if1, InterstateEdge(condition='(flag)', assignments={'outval': 1}))
subg.add_edge(ifh, if2, InterstateEdge(condition='(not flag)', assignments={'outval': 2}))
subg.add_edge(if1, if3, InterstateEdge())
subg.add_edge(if3, if5, InterstateEdge())
subg.add_edge(if2, if4, InterstateEdge())
subg.add_edge(if4, if5, InterstateEdge())
subg.add_edge(if5, if6, InterstateEdge())
subg.add_edge(if5, if7, InterstateEdge())
subg.add_edge(if6, ift, InterstateEdge())
subg.add_edge(if7, ift, InterstateEdge())
t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval')
tmp = ift.add_access('tmp')
ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]'))
subg.fill_scope_connectors()

# Then prepare the parent graph.
g = SDFG('prog')
g.add_array('A', (10,), dace.float32)
g.add_symbol('flag', dace.bool)
st0 = g.add_state('outer', is_start_block=True)
en, ex = st0.add_map('map', {'i': '0:10'})
body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'})
A = st0.add_access('A')
st0.add_memlet_path(en, body, memlet=Memlet())
st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]'))
st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]'))
g.fill_scope_connectors()

return g


def make_branched_sdfg_that_depends_on_loop_var():
"""
Construct a simple SDFG that depends on symbols defined or updated in the outer state, e.g., loop variables.
"""
# First prepare the map-body.
subg = SDFG('body')
subg.add_array('tmp', (1,), dace.float32)
subg.add_symbol('outval', dace.float32)
ifh = subg.add_state('if_head')
if1 = subg.add_state('if_b1')
if2 = subg.add_state('if_b2')
ift = subg.add_state('if_tail')
subg.add_edge(ifh, if1, InterstateEdge(condition='(i == 0)', assignments={'outval': 1}))
subg.add_edge(ifh, if2, InterstateEdge(condition='(not (i == 0))', assignments={'outval': 2}))
subg.add_edge(if1, ift, InterstateEdge())
subg.add_edge(if2, ift, InterstateEdge())
t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval')
tmp = ift.add_access('tmp')
ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]'))
subg.fill_scope_connectors()

# Then prepare the parent graph.
g = SDFG('prog')
g.add_array('A', (10,), dace.float32)
st0 = g.add_state('outer', is_start_block=True)
en, ex = st0.add_map('map', {'i': '0:10'})
body = st0.add_nested_sdfg(subg, None, {}, {'tmp'})
A = st0.add_access('A')
st0.add_memlet_path(en, body, memlet=Memlet())
st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]'))
st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]'))
pratyai marked this conversation as resolved.
Show resolved Hide resolved
g.fill_scope_connectors()

return g


def test_simple_application():
origA = np.zeros((10,), np.float32)

g = make_branched_sdfg_that_does_not_depend_on_loop_var()
g.save(os.path.join('_dacegraphs', 'simple-0.sdfg'))
g.validate()
g.compile()

# Get the expected values.
wantA_1 = deepcopy(origA)
wantA_2 = deepcopy(origA)
g(A=wantA_1, flag=True)
g(A=wantA_2, flag=False)

# Before, the outer graph had only one nested SDFG.
assert len(g.nodes()) == 1

assert g.apply_transformations_repeated([IfExtraction]) == 1
g.save(os.path.join('_dacegraphs', 'simple-1.sdfg'))
g.validate()
g.compile()

# Get the values from transformed program.
gotA_1 = deepcopy(origA)
gotA_2 = deepcopy(origA)
g(A=gotA_1, flag=True)
g(A=gotA_2, flag=False)

# But now, the outer graph have four: two copies of the original nested SDFGs and two for branch management.
assert len(g.nodes()) == 4
assert g.start_state.is_empty()

# Verify numerically.
assert all(np.equal(wantA_1, gotA_1))
assert all(np.equal(wantA_2, gotA_2))


def test_extracts_even_with_intermediate_branchlike_structure():
origA = np.zeros((10,), np.float32)

g = make_branched_sdfg_that_has_intermediate_branchlike_structure()
g.save(os.path.join('_dacegraphs', 'intermediate_branch-0.sdfg'))
g.validate()
g.compile()

# Get the expected values.
wantA_1 = deepcopy(origA)
wantA_2 = deepcopy(origA)
g(A=wantA_1, flag=True)
g(A=wantA_2, flag=False)

# Before, the outer graph had only one nested SDFG.
assert len(g.nodes()) == 1

assert g.apply_transformations_repeated([IfExtraction]) == 1
g.save(os.path.join('_dacegraphs', 'intermediate_branch-1.sdfg'))

# Get the values from transformed program.
gotA_1 = deepcopy(origA)
gotA_2 = deepcopy(origA)
g(A=gotA_1, flag=True)
g(A=gotA_2, flag=False)

# But now, the outer graph have four: two copies of the original nested SDFGs and two for branch management.
assert len(g.nodes()) == 4
assert g.start_state.is_empty()

# Verify numerically.
assert all(np.equal(wantA_1, gotA_1))
assert all(np.equal(wantA_2, gotA_2))


def test_no_extraction_due_to_dependency_on_loop_var():
g = make_branched_sdfg_that_depends_on_loop_var()
g.save(os.path.join('_dacegraphs', 'dependent-0.sdfg'))

assert g.apply_transformations_repeated([IfExtraction]) == 0


if __name__ == '__main__':
test_simple_application()
test_no_extraction_due_to_dependency_on_loop_var()
Loading