Skip to content

Commit

Permalink
Support reference set from tasklets
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Oct 18, 2023
1 parent 373c54c commit 9cf66f2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dace/codegen/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def configure_and_compile(program_folder, program_name=None, output_stream=None)
# Clean CMake directory and try once more
if Config.get_bool('debugprint'):
print('Cleaning CMake build folder and retrying...')
shutil.rmtree(build_folder)
shutil.rmtree(build_folder, ignore_errors=True)
os.makedirs(build_folder)
try:
_run_liveoutput(cmake_command, shell=True, cwd=build_folder, output_stream=output_stream)
Expand Down
18 changes: 13 additions & 5 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,8 @@ def process_out_memlets(self,
_, uconn, v, _, memlet = edge
if skip_wcr and memlet.wcr is not None:
continue
dst_node = dfg.memlet_path(edge)[-1].dst
dst_edge = dfg.memlet_path(edge)[-1]
dst_node = dst_edge.dst

# Target is neither a data nor a tasklet node
if isinstance(node, nodes.AccessNode) and (not isinstance(dst_node, nodes.AccessNode)
Expand Down Expand Up @@ -984,8 +985,9 @@ def process_out_memlets(self,
if isinstance(conntype, dtypes.pointer) and sdfg.arrays[memlet.data].dtype == conntype:
is_scalar = True # Pointer to pointer assignment
is_stream = isinstance(sdfg.arrays[memlet.data], data.Stream)
is_refset = isinstance(sdfg.arrays[memlet.data], data.Reference) and dst_edge.dst_conn == 'set'

if is_scalar and not memlet.dynamic and not is_stream:
if (is_scalar and not memlet.dynamic and not is_stream) or is_refset:
out_local_name = " __" + uconn
in_local_name = uconn
if not locals_defined:
Expand Down Expand Up @@ -1018,6 +1020,9 @@ def process_out_memlets(self,
if defined_type == DefinedType.Scalar:
mname = cpp.ptr(memlet.data, desc, sdfg, self._frame)
write_expr = f"{mname} = {in_local_name};"
elif defined_type == DefinedType.Pointer and is_refset:
mname = cpp.ptr(memlet.data, desc, sdfg, self._frame)
write_expr = f"{mname} = {in_local_name};"
elif (defined_type == DefinedType.ArrayInterface and not isinstance(desc, data.View)):
# Special case: No need to write anything between
# array interfaces going out
Expand Down Expand Up @@ -1503,10 +1508,13 @@ def define_out_memlet(self, sdfg, state_dfg, state_id, src_node, dst_node, edge,
cdtype = src_node.out_connectors[edge.src_conn]
if isinstance(sdfg.arrays[edge.data.data], data.Stream):
pass
elif isinstance(cdtype, dtypes.pointer):
# If pointer, also point to output
elif isinstance(cdtype, dtypes.pointer): # If pointer, also point to output
desc = sdfg.arrays[edge.data.data]
if not isinstance(desc.dtype, dtypes.pointer):

# If reference set, do not emit initial assignment
is_refset = isinstance(desc, data.Reference) and state_dfg.memlet_path(edge)[-1].dst_conn == 'set'

if not is_refset and not isinstance(desc.dtype, dtypes.pointer):
ptrname = cpp.ptr(edge.data.data, desc, sdfg, self._frame)
is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent,
dtypes.AllocationLifetime.External)
Expand Down

0 comments on commit 9cf66f2

Please sign in to comment.