Skip to content

Commit

Permalink
Add more test cases and fix some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Nov 20, 2024
1 parent 506d0aa commit b956142
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 60 deletions.
8 changes: 8 additions & 0 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,14 @@ def memlet_copy_to_absolute_strides(dispatcher: 'TargetDispatcher',
elif memlet.data == dst_node.data:
copy_shape, src_strides = reshape_strides(dst_subset, dst_strides, src_strides, copy_shape)

def replace_dace_defer_dim(string, arrname):
pattern = r"__dace_defer_dim(\d+)"
return re.sub(pattern, r"A_size[\1]", string)

# TODO: do this better?
dst_expr = replace_dace_defer_dim(dst_expr, dst_node.data) if dst_expr is not None else None
src_expr = replace_dace_defer_dim(src_expr, src_node.data) if src_expr is not None else None

return copy_shape, src_strides, dst_strides, src_expr, dst_expr


Expand Down
69 changes: 42 additions & 27 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,27 +756,34 @@ def _emit_copy(

if isinstance(dst_node, nodes.Tasklet):
# Copy into tasklet
desc = sdfg.arrays[memlet.data]
deferred_size_names = self._get_deferred_size_names(desc, memlet)
stream.write(
" " + self.memlet_definition(sdfg, memlet, False, vconn, dst_node.in_connectors[vconn]),
" " + self.memlet_definition(sdfg, memlet, False, vconn, dst_node.in_connectors[vconn], deferred_size_names=deferred_size_names),
cfg,
state_id,
[src_node, dst_node],
)
stream.write(
"//u1"
)
if deferred_size_names is not None:
stream.write(
"// Size uses deferred allocation"
)

return
elif isinstance(src_node, nodes.Tasklet):
# Copy out of tasklet
desc = sdfg.arrays[memlet.data]
deferred_size_names = self._get_deferred_size_names(desc, memlet)
stream.write(
" " + self.memlet_definition(sdfg, memlet, True, uconn, src_node.out_connectors[uconn]),
" " + self.memlet_definition(sdfg, memlet, True, uconn, src_node.out_connectors[uconn], deferred_size_names=deferred_size_names),
cfg,
state_id,
[src_node, dst_node],
)
stream.write(
"//u2"
)
if deferred_size_names is not None:
stream.write(
"// Size uses deferred allocation"
)
return
else: # Copy array-to-array
src_nodedesc = src_node.desc(sdfg)
Expand Down Expand Up @@ -875,6 +882,7 @@ def _emit_copy(

state_dfg: SDFGState = cfg.nodes()[state_id]


copy_shape, src_strides, dst_strides, src_expr, dst_expr = cpp.memlet_copy_to_absolute_strides(
self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, self._packed_types)

Expand Down Expand Up @@ -1043,6 +1051,27 @@ def write_and_resolve_expr(self, sdfg: SDFG, memlet: mmlt.Memlet, nc: bool, outn
custom_reduction = cpp.unparse_cr(sdfg, memlet.wcr, dtype)
return (f'dace::wcr_custom<{dtype.ctype}>:: template {func}({custom_reduction}, {ptr}, {inname})')

def _get_deferred_size_names(self, desc, memlet):
if (desc.storage != dtypes.StorageType.GPU_Global and
desc.storage != dtypes.StorageType.CPU_Heap and
not desc.transient):
return None
def check_dace_defer(elements):
for elem in elements:
if isinstance(elem, symbolic.symbol) and str(elem).startswith("__dace_defer"):
return True
return False
deferred_size_names = None
if check_dace_defer(desc.shape):
if desc.storage == dtypes.StorageType.GPU_Global or desc.storage == dtypes.StorageType.CPU_Heap:
deferred_size_names = []
for i, elem in enumerate(desc.shape):
if str(elem).startswith("__dace_defer"):
deferred_size_names.append(f"__{memlet.data}_dim{i}_size" if desc.storage == dtypes.StorageType.GPU_Global else f"{desc.size_desc_name}[{i}]")
else:
deferred_size_names.append(elem)
return deferred_size_names if len(deferred_size_names) > 0 else None

def process_out_memlets(self,
sdfg: SDFG,
cfg: ControlFlowRegion,
Expand Down Expand Up @@ -1179,22 +1208,7 @@ def process_out_memlets(self,
# If the storage type if CPU_Heap or GPU_Global then it might be requiring deferred allocation
# We can check if the array requires sepcial access using A_size[0] (CPU) or __A_dim0_size (GPU0)
# by going through the shape and checking for symbols starting with __dace_defer
def check_dace_defer(elements):
for elem in elements:
if isinstance(elem, symbolic.symbol) and str(elem).startswith("__dace_defer"):
return True
return False
deferred_size_names = None
if check_dace_defer(desc.shape):
if desc.storage == dtypes.StorageType.GPU_Global or desc.storage == dtypes.StorageType.CPU_Heap:
deferred_size_names = []
for i, elem in enumerate(desc.shape):
if str(elem).startswith("__dace_defer"):
deferred_size_names.append(f"__{memlet.data}_dim{i}_size" if desc.storage == dtypes.StorageType.GPU_Global else f"{desc.size_desc_name}[{i}]")
else:
deferred_size_names.append(elem)
else:
raise Exception("Deferred Allocation only supported on array storages of type GPU_Global or CPU_Heap")
deferred_size_names = self._get_deferred_size_names(desc, memlet)
expr = cpp.cpp_array_expr(sdfg, memlet, codegen=self._frame, deferred_size_names=deferred_size_names)
write_expr = codegen.make_ptr_assignment(in_local_name, conntype, expr, desc_dtype)

Expand Down Expand Up @@ -1332,7 +1346,8 @@ def memlet_definition(self,
local_name: str,
conntype: Union[data.Data, dtypes.typeclass] = None,
allow_shadowing: bool = False,
codegen: 'CPUCodeGen' = None):
codegen: 'CPUCodeGen' = None,
deferred_size_names = None):
# TODO: Robust rule set
if conntype is None:
raise ValueError('Cannot define memlet for "%s" without connector type' % local_name)
Expand Down Expand Up @@ -1381,7 +1396,7 @@ def memlet_definition(self,
decouple_array_interfaces=decouple_array_interfaces)

result = ''
expr = (cpp.cpp_array_expr(sdfg, memlet, with_brackets=False, codegen=self._frame)
expr = (cpp.cpp_array_expr(sdfg, memlet, with_brackets=False, codegen=self._frame, deferred_size_names=deferred_size_names)
if var_type in [DefinedType.Pointer, DefinedType.StreamArray, DefinedType.ArrayInterface] else ptr)

if expr != ptr:
Expand Down Expand Up @@ -1425,7 +1440,7 @@ def memlet_definition(self,
if not memlet.dynamic and memlet.num_accesses == 1:
if not output:
if isinstance(desc, data.Stream) and desc.is_stream_array():
index = cpp.cpp_offset_expr(desc, memlet.subset)
index = cpp.cpp_offset_expr(desc, memlet.subset, deferred_size_names=deferred_size_names)
expr = f"{memlet.data}[{index}]"
result += f'{memlet_type} {local_name} = ({expr}).pop();'
defined = DefinedType.Scalar
Expand Down
30 changes: 15 additions & 15 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def covers(self, other):
return False

return True

def covers_precise(self, other):
""" Returns True if self contains all the elements in other. """

Expand Down Expand Up @@ -734,7 +734,7 @@ def compose(self, other):
def squeeze(self, ignore_indices: Optional[List[int]] = None, offset: bool = True) -> List[int]:
"""
Removes size-1 ranges from the subset and returns a list of dimensions that remain.
For example, ``[i:i+10, j]`` will change the range to ``[i:i+10]`` and return ``[0]``.
If ``offset`` is True, the subset will become ``[0:10]``.
Expand Down Expand Up @@ -770,7 +770,7 @@ def squeeze(self, ignore_indices: Optional[List[int]] = None, offset: bool = Tru

def unsqueeze(self, axes: Sequence[int]) -> List[int]:
""" Adds 0:1 ranges to the subset, in the indices contained in axes.
The method is mostly used to restore subsets that had their length-1
ranges removed (i.e., squeezed subsets). Hence, the method is
called 'unsqueeze'.
Expand Down Expand Up @@ -1046,7 +1046,7 @@ def squeeze(self, ignore_indices=None):

def unsqueeze(self, axes: Sequence[int]) -> List[int]:
""" Adds zeroes to the subset, in the indices contained in axes.
The method is mostly used to restore subsets that had their
zero-indices removed (i.e., squeezed subsets). Hence, the method is
called 'unsqueeze'.
Expand Down Expand Up @@ -1112,7 +1112,7 @@ def __init__(self, subset):
self.subset_list = [subset]

def covers(self, other):
"""
"""
Returns True if this SubsetUnion covers another subset (using a bounding box).
If other is another SubsetUnion then self and other will
only return true if self is other. If other is a different type of subset
Expand All @@ -1128,13 +1128,13 @@ def covers(self, other):
return False
else:
return any(s.covers(other) for s in self.subset_list)

def covers_precise(self, other):
"""
"""
Returns True if this SubsetUnion covers another
subset. If other is another SubsetUnion then self and other will
only return true if self is other. If other is a different type of subset
true is returned when one of the subsets in self is equal to other
true is returned when one of the subsets in self is equal to other
"""

if isinstance(other, SubsetUnion):
Expand All @@ -1154,7 +1154,7 @@ def __str__(self):
string += " "
string += subset.__str__()
return string

def dims(self):
if not self.subset_list:
return 0
Expand All @@ -1178,7 +1178,7 @@ def free_symbols(self) -> Set[str]:
for subset in self.subset_list:
result |= subset.free_symbols
return result

def replace(self, repl_dict):
for subset in self.subset_list:
subset.replace(repl_dict)
Expand All @@ -1192,15 +1192,15 @@ def num_elements(self):
min = subset.num_elements()
except:
continue

return min



def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType,
bre: symbolic.SymbolicType):
"""
Special cases of subset unions. If case found, returns pair of
"""
Special cases of subset unions. If case found, returns pair of
(min,max), otherwise returns None.
"""
if are + 1 == brb:
Expand Down Expand Up @@ -1267,7 +1267,7 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:
""" Compute the union of two Subset objects.
If the subsets are not of the same type, degenerates to bounding-box
union.
:param subset_a: The first subset.
:param subset_b: The second subset.
:return: A Subset object whose size is at least the union of the two
Expand Down Expand Up @@ -1303,7 +1303,7 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:


def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
"""
"""
Returns the union of two Subset lists.
:param subset_a: The first subset.
Expand Down
Loading

0 comments on commit b956142

Please sign in to comment.