Skip to content

Commit

Permalink
Check host_maps and host_data in the GPU transformations (#1701)
Browse files Browse the repository at this point in the history
Co-authored-by: Tal Ben-Nun <[email protected]>
  • Loading branch information
ThrudPrimrose and tbennun authored Nov 29, 2024
1 parent 4f8eb92 commit af87662
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 15 deletions.
14 changes: 9 additions & 5 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ def clear_data_reports(self):

def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, **kwargs):
"""
Invokes an SDFG with an instrumented data report, generating and compiling code if necessary.
Invokes an SDFG with an instrumented data report, generating and compiling code if necessary.
Arguments given as ``args`` and ``kwargs`` will be overridden by the data containers defined in the report.
:param dreport: The instrumented data report to use upon calling.
Expand Down Expand Up @@ -2690,7 +2690,7 @@ def apply_transformations_once_everywhere(self,
print_report: Optional[bool] = None,
order_by_transformation: bool = True,
progress: Optional[bool] = None) -> int:
"""
"""
This function applies a transformation or a set of (unique) transformations
once throughout the entire SDFG. Operates in-place.
Expand Down Expand Up @@ -2738,7 +2738,9 @@ def apply_gpu_transformations(self,
permissive=False,
sequential_innermaps=True,
register_transients=True,
simplify=True):
simplify=True,
host_maps=None,
host_data=None):
""" Applies a series of transformations on the SDFG for it to
generate GPU code.
Expand All @@ -2755,7 +2757,9 @@ def apply_gpu_transformations(self,
self.apply_transformations(GPUTransformSDFG,
options=dict(sequential_innermaps=sequential_innermaps,
register_trans=register_transients,
simplify=simplify),
simplify=simplify,
host_maps=host_maps,
host_data=host_data),
validate=validate,
validate_all=validate_all,
permissive=permissive,
Expand Down Expand Up @@ -2806,7 +2810,7 @@ def expand_library_nodes(self, recursive=True):

def generate_code(self):
""" Generates code from this SDFG and returns it.
:return: A list of `CodeObject` objects containing the generated
code of different files and languages.
"""
Expand Down
55 changes: 45 additions & 10 deletions dace/transformation/interstate/gpu_transform_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dace.sdfg import nodes, scope
from dace.sdfg import utils as sdutil
from dace.transformation import transformation, helpers as xfh
from dace.properties import Property, make_properties
from dace.properties import ListProperty, Property, make_properties
from collections import defaultdict
from copy import deepcopy as dc
from sympy import floor
Expand Down Expand Up @@ -128,6 +128,12 @@ class GPUTransformSDFG(transformation.MultiStateTransformation):
dtype=str,
default='')

host_maps = ListProperty(desc='List of map GUIDs, the passed maps are not offloaded to the GPU',
element_type=str, default=None, allow_none=True)

host_data = ListProperty(desc='List of data names, the passed data are not offloaded to the GPU',
element_type=str, default=None, allow_none=True)

@staticmethod
def annotates_memlets():
# Skip memlet propagation for now
Expand All @@ -154,19 +160,44 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
return False
return True

def apply(self, _, sdfg: sd.SDFG):
def _get_marked_inputs_and_outputs(self, state, entry_node) -> list:
    """Return the data names of ``entry_node``'s inputs/outputs that are listed in ``self.host_data``.

    :param state: The state containing ``entry_node``.
    :param entry_node: A map entry node whose outermost inputs and outputs are inspected.
    :return: List of data-container names (possibly with duplicates) found both
             around the map and in ``self.host_data``; empty list if neither
             ``host_data`` nor ``host_maps`` is set.
    """
    # Fast exit: nothing was marked by the user, so nothing can match.
    if not self.host_data and not self.host_maps:
        return []
    # Outermost source of each memlet tree entering the map entry.
    marked_sources = [state.memlet_tree(e).root().edge.src for e in state.in_edges(entry_node)]
    # Views are transparent: resolve each view to the access node it views.
    marked_sources = [sdutil.get_view_node(state, node) if isinstance(node, data.View) else node for node in marked_sources]
    # Same for the map's outputs: outermost destination of each tree leaving the exit node.
    marked_destinations = [state.memlet_tree(e).root().edge.dst for e in state.in_edges(state.exit_node(entry_node))]
    marked_destinations = [sdutil.get_view_node(state, node) if isinstance(node, data.View) else node for node in marked_destinations]
    # NOTE(review): the final filter keeps only names already present in
    # ``self.host_data``. When this method is used to *discover* a host map's
    # accesses (before host_data is extended), only pre-marked names are
    # returned — TODO confirm this is the intended behavior.
    marked_accesses = [n.data for n in (marked_sources + marked_destinations) if n is not None and isinstance(n, nodes.AccessNode) and n.data in self.host_data]
    return marked_accesses

def _output_or_input_is_marked_host(self, state, entry_node) -> bool:
    """Return True iff any input or output of ``entry_node`` is marked as host data.

    Convenience predicate over :meth:`_get_marked_inputs_and_outputs`, used to
    decide whether a map may be scheduled on the GPU.
    """
    marked_accesses = self._get_marked_inputs_and_outputs(state, entry_node)
    return len(marked_accesses) > 0


def apply(self, _, sdfg: sd.SDFG):
#######################################################
# Step 0: SDFG metadata

# Find all input and output data descriptors
input_nodes = []
output_nodes = []
global_code_nodes: Dict[sd.SDFGState, nodes.Tasklet] = defaultdict(list)
if self.host_maps is None:
self.host_maps = []
if self.host_data is None:
self.host_data = []

# Propagate memlets to ensure that we can find the true array subsets that are written.
propagate_memlets_sdfg(sdfg)

# Inputs and outputs of all host_maps need to be marked as host_data
for state in sdfg.nodes():
for node in state.nodes():
if isinstance(node, nodes.EntryNode) and node.guid in self.host_maps:
accesses = self._get_marked_inputs_and_outputs(state, node)
self.host_data.extend(accesses)

for state in sdfg.nodes():
sdict = state.scope_dict()
for node in state.nodes():
Expand All @@ -176,12 +207,13 @@ def apply(self, _, sdfg: sd.SDFG):
# map ranges must stay on host
for e in state.out_edges(node):
last_edge = state.memlet_path(e)[-1]
if (isinstance(last_edge.dst, nodes.EntryNode) and last_edge.dst_conn
and not last_edge.dst_conn.startswith('IN_') and sdict[last_edge.dst] is None):
if (isinstance(last_edge.dst, nodes.EntryNode) and ((last_edge.dst_conn
and not last_edge.dst_conn.startswith('IN_') and sdict[last_edge.dst] is None) or
(last_edge.dst in self.host_maps))):
break
else:
input_nodes.append((node.data, node.desc(sdfg)))
if (state.in_degree(node) > 0 and node.data not in output_nodes):
if (state.in_degree(node) > 0 and node.data not in output_nodes and node.data not in self.host_data):
output_nodes.append((node.data, node.desc(sdfg)))

# Input nodes may also be nodes with WCR memlets and no identity
Expand Down Expand Up @@ -312,11 +344,13 @@ def apply(self, _, sdfg: sd.SDFG):
for node in state.nodes():
if sdict[node] is None:
if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)):
node.schedule = dtypes.ScheduleType.GPU_Default
gpu_nodes.add((state, node))
if node.guid:
node.schedule = dtypes.ScheduleType.GPU_Default
gpu_nodes.add((state, node))
elif isinstance(node, nodes.EntryNode):
node.schedule = dtypes.ScheduleType.GPU_Device
gpu_nodes.add((state, node))
if node.guid not in self.host_maps and not self._output_or_input_is_marked_host(state, node):
node.schedule = dtypes.ScheduleType.GPU_Device
gpu_nodes.add((state, node))
elif self.sequential_innermaps:
if isinstance(node, (nodes.EntryNode, nodes.LibraryNode)):
node.schedule = dtypes.ScheduleType.Sequential
Expand Down Expand Up @@ -423,7 +457,8 @@ def apply(self, _, sdfg: sd.SDFG):
continue

# NOTE: the cloned arrays match too but it's the same storage so we don't care
nodedesc.storage = dtypes.StorageType.GPU_Global
if node.data not in self.host_data:
nodedesc.storage = dtypes.StorageType.GPU_Global

# Try to move allocation/deallocation out of loops
dsyms = set(map(str, nodedesc.free_symbols))
Expand Down
145 changes: 145 additions & 0 deletions tests/host_map_host_data_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import dace
import pytest

def create_assign_sdfg():
    """Build an SDFG with a single-iteration map whose tasklet assigns 1 to A[0].

    Returns:
        Tuple of (array name, SDFG): the name of the written array ('A') and
        the validated SDFG.
    """
    sdfg = dace.SDFG('single_iteration_map')
    state = sdfg.add_state()
    array_size = 1
    A, _ = sdfg.add_array('A', [array_size], dace.float32)
    map_entry, map_exit = state.add_map('map_1_iter', {'i': '0:1'})
    tasklet = state.add_tasklet('set_to_1', {}, {'OUT__a'}, '_a = 1')
    map_exit.add_in_connector('IN__a')
    map_exit.add_out_connector('OUT__a')
    tasklet.add_out_connector('OUT__a')
    an = state.add_write('A')
    # The tasklet has no inputs; an empty memlet merely nests it inside the map.
    state.add_edge(map_entry, None, tasklet, None, dace.Memlet())
    # Plain strings instead of f-strings: the subsets contain no interpolation.
    state.add_edge(tasklet, 'OUT__a', map_exit, 'IN__a', dace.Memlet('A[0]'))
    state.add_edge(map_exit, 'OUT__a', an, None, dace.Memlet('A[0]'))
    sdfg.validate()
    return A, sdfg

def create_assign_sdfg_with_views():
    """Build an SDFG that writes 1 to element 0 of array A through a view v_A.

    Returns:
        Tuple of (view name, SDFG): the name of the written view ('v_A') and
        the validated SDFG.
    """
    sdfg = dace.SDFG('single_iteration_map')
    state = sdfg.add_state()
    array_size = 5
    _, _ = sdfg.add_array('A', [array_size], dace.float32)
    v_A, _ = sdfg.add_view('v_A', [1], dace.float32)
    map_entry, map_exit = state.add_map('map_1_iter', {'i': '0:1'})
    tasklet = state.add_tasklet('set_to_1', {}, {'OUT__a'}, '_a = 1')
    map_exit.add_in_connector('IN__a')
    map_exit.add_out_connector('OUT__a')
    tasklet.add_out_connector('OUT__a')
    an = state.add_write('v_A')
    an2 = state.add_write('A')
    an.add_out_connector('views')
    # Empty memlet (was Memlet(None)) for consistency with create_assign_sdfg;
    # it only nests the input-less tasklet inside the map.
    state.add_edge(map_entry, None, tasklet, None, dace.Memlet())
    # Plain strings instead of f-strings: the subsets contain no interpolation.
    state.add_edge(tasklet, 'OUT__a', map_exit, 'IN__a', dace.Memlet('v_A[0]'))
    state.add_edge(map_exit, 'OUT__a', an, None, dace.Memlet('v_A[0]'))
    # The 'views' connector materializes the view into the full array A.
    state.add_edge(an, 'views', an2, None, dace.Memlet('A[0:1]'))
    sdfg.validate()
    return v_A, sdfg

def create_increment_sdfg():
    """Build an SDFG whose map increments every element of a 500-element array A.

    Returns:
        Tuple of (array name, SDFG): the name of the array ('A') and the
        validated SDFG.
    """
    sdfg = dace.SDFG('increment_map')
    state = sdfg.add_state()
    array_size = 500
    A, _ = sdfg.add_array('A', [array_size], dace.float32)
    map_entry, map_exit = state.add_map('map_1_iter', {'i': f'0:{array_size}'})
    tasklet = state.add_tasklet('inc_by_1', {}, {'OUT__a'}, '_a = _a + 1')
    map_entry.add_in_connector('IN__a')
    map_entry.add_out_connector('OUT__a')
    map_exit.add_in_connector('IN__a')
    map_exit.add_out_connector('OUT__a')
    tasklet.add_in_connector('IN__a')
    tasklet.add_out_connector('OUT__a')
    an1 = state.add_read('A')
    an2 = state.add_write('A')
    # Plain strings instead of f-strings: 'A[i]' refers to the map symbol i,
    # not a Python variable, so no interpolation happens.
    state.add_edge(an1, None, map_entry, 'IN__a', dace.Memlet('A[i]'))
    state.add_edge(map_entry, 'OUT__a', tasklet, 'IN__a', dace.Memlet())
    state.add_edge(tasklet, 'OUT__a', map_exit, 'IN__a', dace.Memlet('A[i]'))
    state.add_edge(map_exit, 'OUT__a', an2, None, dace.Memlet('A[i]'))
    sdfg.validate()
    return A, sdfg

def create_increment_sdfg_with_views():
    """Build an SDFG that increments the first 100 elements of A through a view v_A.

    Returns:
        Tuple of (view name, SDFG): the name of the view ('v_A') and the
        validated SDFG.
    """
    sdfg = dace.SDFG('increment_map')
    state = sdfg.add_state()
    array_size = 500
    view_size = 100
    _, _ = sdfg.add_array('A', [array_size], dace.float32)
    v_A, _ = sdfg.add_view('v_A', [view_size], dace.float32)
    map_entry, map_exit = state.add_map('map_1_iter', {'i': f'0:{view_size}'})
    tasklet = state.add_tasklet('inc_by_1', {}, {'OUT__a'}, '_a = _a + 1')
    map_entry.add_in_connector('IN__a')
    map_entry.add_out_connector('OUT__a')
    map_exit.add_in_connector('IN__a')
    map_exit.add_out_connector('OUT__a')
    tasklet.add_in_connector('IN__a')
    tasklet.add_out_connector('OUT__a')
    an1 = state.add_read('A')
    an2 = state.add_write('A')
    an3 = state.add_read('v_A')
    an4 = state.add_write('v_A')
    an3.add_in_connector('views')
    an4.add_out_connector('views')
    # Use view_size instead of the hard-coded 100 so the view subset always
    # matches the declared view shape (evaluates to 'A[0:100]' as before).
    state.add_edge(an1, None, an3, 'views', dace.Memlet(f'A[0:{view_size}]'))
    state.add_edge(an3, None, map_entry, 'IN__a', dace.Memlet('v_A[i]'))
    state.add_edge(map_entry, 'OUT__a', tasklet, 'IN__a', dace.Memlet('v_A[i]'))
    state.add_edge(tasklet, 'OUT__a', map_exit, 'IN__a', dace.Memlet('v_A[i]'))
    state.add_edge(map_exit, 'OUT__a', an4, None, dace.Memlet('v_A[i]'))
    state.add_edge(an4, 'views', an2, None, dace.Memlet(f'A[0:{view_size}]'))
    sdfg.validate()
    return v_A, sdfg

@pytest.mark.parametrize("sdfg_creator", [
    create_assign_sdfg,
    create_increment_sdfg,
    create_assign_sdfg_with_views,
    create_increment_sdfg_with_views,
])
class TestHostDataHostMapParams:

    def test_host_data(self, sdfg_creator):
        """Arrays passed via host_data must keep host storage after the GPU transform."""
        array_name, sdfg = sdfg_creator()
        sdfg.apply_gpu_transformations(host_data=[array_name])
        sdfg.validate()

        assert sdfg.arrays['A'].storage != dace.dtypes.StorageType.GPU_Global

    def test_host_map(self, sdfg_creator):
        """Maps passed via host_maps must keep their data on the host."""
        _, sdfg = sdfg_creator()
        entry_guids = []
        for state in sdfg.states():
            entry_guids.extend(node.guid for node in state.nodes()
                               if isinstance(node, dace.nodes.EntryNode))
        sdfg.apply_gpu_transformations(host_maps=entry_guids)
        sdfg.validate()
        assert sdfg.arrays['A'].storage != dace.dtypes.StorageType.GPU_Global

    @pytest.mark.parametrize("pass_empty", [True, False])
    def test_no_host_map_or_data(self, sdfg_creator, pass_empty):
        """With no host constraints (absent or empty), everything is offloaded."""
        _, sdfg = sdfg_creator()

        kwargs = {'host_maps': [], 'host_data': []} if pass_empty else {}
        sdfg.apply_gpu_transformations(**kwargs)
        sdfg.validate()

        # The original array stays on host; a GPU clone is created.
        assert 'A' in sdfg.arrays and 'gpu_A' in sdfg.arrays
        assert sdfg.arrays['A'].storage != dace.dtypes.StorageType.GPU_Global
        assert sdfg.arrays['gpu_A'].storage == dace.dtypes.StorageType.GPU_Global

        # Every map entry must have been scheduled on the GPU.
        for state in sdfg.states():
            for node in state.nodes():
                if isinstance(node, dace.nodes.MapEntry):
                    assert node.map.schedule == dace.ScheduleType.GPU_Device

# Allow running this test module directly (python host_map_host_data_test.py)
# in addition to discovery via the pytest CLI.
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit af87662

Please sign in to comment.