Skip to content

Commit

Permalink
Merge branch 'master' into fixed_map_fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller authored Feb 27, 2024
2 parents dff7343 + 608aa80 commit 7ac3e4c
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 9 deletions.
8 changes: 5 additions & 3 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dace.sdfg import (ScopeSubgraphView, SDFG, scope_contains_scope, is_array_stream_view, NodeNotExpandedError,
dynamic_map_inputs, local_transients)
from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga, is_in_scope
from dace.sdfg.validation import validate_memlet_data
from typing import Union
from dace.codegen.targets import fpga

Expand All @@ -40,7 +41,7 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''):
_visit_structure(v, args, f'{prefix}->{k}')
elif isinstance(v, data.ContainerArray):
_visit_structure(v.stype, args, f'{prefix}->{k}')
elif isinstance(v, data.Data):
if isinstance(v, data.Data):
args[f'{prefix}->{k}'] = v

# Keeps track of generated connectors, so we know how to access them in nested scopes
Expand Down Expand Up @@ -620,6 +621,7 @@ def copy_memory(
callsite_stream,
)


def _emit_copy(
self,
sdfg,
Expand All @@ -637,9 +639,9 @@ def _emit_copy(
orig_vconn = vconn

# Determine memlet directionality
if isinstance(src_node, nodes.AccessNode) and memlet.data == src_node.data:
if isinstance(src_node, nodes.AccessNode) and validate_memlet_data(memlet.data, src_node.data):
write = True
elif isinstance(dst_node, nodes.AccessNode) and memlet.data == dst_node.data:
elif isinstance(dst_node, nodes.AccessNode) and validate_memlet_data(memlet.data, dst_node.data):
write = False
elif isinstance(src_node, nodes.CodeNode) and isinstance(dst_node, nodes.CodeNode):
# Code->Code copy (not read nor write)
Expand Down
7 changes: 6 additions & 1 deletion dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend:
if arr is not None:
datatypes.add(arr.dtype)

emitted = set()

def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool:
if isinstance(dtype, dtypes.pointer):
wrote_something = _emit_definitions(dtype._typeclass, wrote_something)
Expand All @@ -164,7 +166,10 @@ def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool:
if hasattr(dtype, 'emit_definition'):
if not wrote_something:
global_stream.write("", sdfg)
global_stream.write(dtype.emit_definition(), sdfg)
if dtype not in emitted:
global_stream.write(dtype.emit_definition(), sdfg)
wrote_something = True
emitted.add(dtype)
return wrote_something

# Emit unique definitions
Expand Down
6 changes: 4 additions & 2 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,8 +1449,10 @@ def validate_name(name):
return False
if name in {'True', 'False', 'None'}:
return False
if namere.match(name) is None:
return False
tokens = name.split('.')
for token in tokens:
if namere.match(token) is None:
return False
return True


Expand Down
4 changes: 2 additions & 2 deletions dace/sdfg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,10 @@ def edges(self):
return [DiGraph._from_nx(e) for e in self._nx.edges()]

def in_edges(self, node):
return [DiGraph._from_nx(e) for e in self._nx.in_edges()]
return [DiGraph._from_nx(e) for e in self._nx.in_edges(node, True)]

def out_edges(self, node):
return [DiGraph._from_nx(e) for e in self._nx.out_edges()]
return [DiGraph._from_nx(e) for e in self._nx.out_edges(node, True)]

def add_node(self, node):
return self._nx.add_node(node)
Expand Down
4 changes: 3 additions & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,9 @@ def replace_dict(self,

# Replace in arrays and symbols (if a variable name)
if replace_keys:
for name, new_name in repldict.items():
# Filter out nested data names, as we cannot and do not want to replace names in nested data descriptors
repldict_filtered = {k: v for k, v in repldict.items() if '.' not in k}
for name, new_name in repldict_filtered.items():
if validate_name(new_name):
_replace_dict_keys(self._arrays, name, new_name)
_replace_dict_keys(self.symbols, name, new_name)
Expand Down
17 changes: 17 additions & 0 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,20 @@ def __str__(self):
locinfo += f'\nInvalid SDFG saved for inspection in {os.path.abspath(self.path)}'

return f'{self.message} (at state {state.label}{edgestr}){locinfo}'


def validate_memlet_data(memlet_data: str, access_data: str) -> bool:
""" Validates that the src/dst access node data matches the memlet data.
:param memlet_data: The data of the memlet.
:param access_data: The data of the access node.
:return: True if the memlet data matches the access node data.
"""
if memlet_data == access_data:
return True
if memlet_data is None or access_data is None:
return False
access_tokens = access_data.split('.')
memlet_tokens = memlet_data.split('.')
mem_root = '.'.join(memlet_tokens[:len(access_tokens)])
return mem_root == access_data
59 changes: 59 additions & 0 deletions tests/sdfg/data/container_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,67 @@ def test_two_levels():
assert np.allclose(ref, B[0])


def test_multi_nested_containers():

M, N = dace.symbol('M'), dace.symbol('N')
sdfg = dace.SDFG('tester')
float_desc = dace.data.Scalar(dace.float32)
E_desc = dace.data.Structure({'F': dace.float32[N], 'G':float_desc}, 'InnerStruct')
B_desc = dace.data.ContainerArray(E_desc, [M])
A_desc = dace.data.Structure({'B': B_desc, 'C': dace.float32[M], 'D': float_desc}, 'OuterStruct')
sdfg.add_datadesc('A', A_desc)
sdfg.add_datadesc_view('vB', B_desc)
sdfg.add_datadesc_view('vE', E_desc)
sdfg.add_array('out', [M, N], dace.float32)

state = sdfg.add_state()
rA = state.add_read('A')
vB = state.add_access('vB')
vE = state.add_access('vE')
wout = state.add_write('out')

me, mx = state.add_map('outer_product', dict(i='0:M', j='0:N'))
tasklet = state.add_tasklet('outer_product', {'__in_A_B_E_F', '__in_A_B_E_G', '__in_A_C', '__in_A_D'}, {'__out'},
'__out = (__in_A_B_E_F + __in_A_B_E_G) * (__in_A_C + __in_A_D)')

state.add_edge(rA, None, vB, 'views', dace.Memlet('A.B'))
state.add_memlet_path(vB, me, vE, dst_conn='views', memlet=dace.Memlet('vB[i]'))
state.add_edge(vE, None, tasklet, '__in_A_B_E_F', dace.Memlet('vE.F[j]'))
state.add_edge(vE, None, tasklet, '__in_A_B_E_G', dace.Memlet(data='vE.G', subset='0'))
state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_C', memlet=dace.Memlet('A.C[i]'))
state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_D', memlet=dace.Memlet(data='A.D', subset='0'))
state.add_memlet_path(tasklet, mx, wout, src_conn='__out', memlet=dace.Memlet('out[i, j]'))

c_data = np.arange(5, dtype=np.float32)
f_data = np.arange(5 * 3, dtype=np.float32).reshape(5, 3)

e_class = E_desc.dtype._typeclass.as_ctypes()
b_obj = []
b_data = np.ndarray((5, ), dtype=ctypes.c_void_p)
for i in range(5):
f_obj = f_data[i].__array_interface__['data'][0]
e_obj = e_class(F=f_obj, G=ctypes.c_float(0.1))
b_obj.append(e_obj) # NOTE: This is needed to keep the object alive ...
b_data[i] = ctypes.addressof(e_obj)
a_dace = A_desc.dtype._typeclass.as_ctypes()(B=b_data.__array_interface__['data'][0],
C=c_data.__array_interface__['data'][0],
D=ctypes.c_float(0.2))




out_dace = np.empty((5, 3), dtype=np.float32)
ref = np.empty((5, 3), dtype=np.float32)
for i in range(5):
ref[i] = (f_data[i] + 0.1) * (c_data[i] + 0.2)

sdfg(A=a_dace, out=out_dace, M=5, N=3)
assert np.allclose(out_dace, ref)


if __name__ == '__main__':
test_read_struct_array()
test_write_struct_array()
test_jagged_container_array()
test_two_levels()
test_multi_nested_containers()

0 comments on commit 7ac3e4c

Please sign in to comment.