Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Nov 13, 2023
1 parent 8eeb622 commit 15fb33c
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 79 deletions.
6 changes: 4 additions & 2 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def copy_expr(
packed_types=False,
):
data_desc = sdfg.arrays[data_name]
# TODO: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
# NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
# TODO: Study this when changing Structures to be (optionally?) non-pointers.
tokens = data_name.split('.')
if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure):
name = data_name.replace('.', '->')
Expand Down Expand Up @@ -585,7 +586,8 @@ def cpp_array_expr(sdfg,
desc = (sdfg.arrays[memlet.data] if referenced_array is None else referenced_array)
offset_cppstr = cpp_offset_expr(desc, s, o, packed_veclen, indices=indices)

# TODO: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
# NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
# TODO: Study this when changing Structures to be (optionally?) non-pointers.
tokens = memlet.data.split('.')
if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure):
name = memlet.data.replace('.', '->')
Expand Down
3 changes: 3 additions & 0 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d

tokens = node.data.split('.')
top_desc = sdfg.arrays[tokens[0]]
# NOTE: Assuming here that all Structure members share transient/storage/lifetime properties.
# TODO: Study what is needed in the DaCe stack to ensure this assumption is correct.
top_transient = top_desc.transient
top_storage = top_desc.storage
top_lifetime = top_desc.lifetime
Expand Down Expand Up @@ -644,6 +646,7 @@ def _emit_copy(
#############################################
# Corner cases

# NOTE: This looks obsolete but keeping it commented out in case tests fail.
# Writing one index
# if (isinstance(memlet.subset, subsets.Indices) and memlet.wcr is None
# and self._dispatcher.defined_vars.get(vconn)[0] == DefinedType.Scalar):
Expand Down
10 changes: 2 additions & 8 deletions dace/codegen/targets/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,6 @@ def preprocess(self, sdfg: SDFG) -> None:
and node.map.schedule in (dtypes.ScheduleType.GPU_Device, dtypes.ScheduleType.GPU_Persistent)):
if state.parent not in shared_transients:
shared_transients[state.parent] = state.parent.shared_transients()
# sgraph = state.scope_subgraph(node)
# used_symbols = sgraph.used_symbols(all_symbols=False)
# arglist = sgraph.arglist(defined_syms, shared_transients[state.parent])
# arglist = {k: v for k, v in arglist.items() if not k in defined_syms or k in used_symbols}
# self._arglists[node] = arglist
# TODO/NOTE: Did we change defined_syms?
self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.parent])

def _compute_pool_release(self, top_sdfg: SDFG):
Expand Down Expand Up @@ -1029,11 +1023,11 @@ def _emit_copy(self, state_id, src_node, src_storage, dst_node, dst_storage, dst
if issubclass(node_dtype.type, ctypes.Structure):
callsite_stream.write('for (size_t __idx = 0; __idx < {arrlen}; ++__idx) '
'{{'.format(arrlen=array_length))
# for field_name, field_type in node_dtype._data.items():
# TODO: Study further when tackling Structures on GPU.
for field_name, field_type in node_dtype._typeclass.fields.items():
if isinstance(field_type, dtypes.pointer):
tclass = field_type.type
# length = node_dtype._length[field_name]

length = node_dtype._typeclass._length[field_name]
size = 'sizeof({})*{}[__idx].{}'.format(dtypes._CTYPES[tclass], str(src_node), length)
callsite_stream.write('DACE_GPU_CHECK({backend}Malloc(&{dst}[__idx].{fname}, '
Expand Down
3 changes: 2 additions & 1 deletion dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,8 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
access_instances[sdfg.sdfg_id] = instances

for sdfg, name, desc in top_sdfg.arrays_recursive(include_nested_data=True):
# NOTE/TODO: Temporary fix for nested data not having the same attributes as their parent
# NOTE: Assuming here that all Structure members share transient/storage/lifetime properties.
# TODO: Study what is needed in the DaCe stack to ensure this assumption is correct.
top_desc = sdfg.arrays[name.split('.')[0]]
top_transient = top_desc.transient
top_storage = top_desc.storage
Expand Down
47 changes: 23 additions & 24 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
Shape = Union[ShapeTuple, ShapeList]
DependencyType = Dict[str, Tuple[SDFGState, Union[Memlet, nodes.Tasklet], Tuple[int]]]


if sys.version_info < (3, 8):
_simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num)
BytesConstant = ast.Bytes
Expand All @@ -65,15 +64,13 @@
NumConstant = ast.Constant
StrConstant = ast.Constant


if sys.version_info < (3, 9):
Index = ast.Index
ExtSlice = ast.ExtSlice
else:
Index = type(None)
ExtSlice = type(None)


if sys.version_info < (3, 12):
TypeAlias = type(None)
else:
Expand Down Expand Up @@ -452,10 +449,11 @@ def add_indirection_subgraph(sdfg: SDFG,
for i, r in enumerate(memlet.subset):
if i in nonsqz_dims:
mapped_rng.append(r)
ind_entry, ind_exit = graph.add_map(
'indirection', {'__i%d' % i: '%s:%s+1:%s' % (s, e, t)
for i, (s, e, t) in enumerate(mapped_rng)},
debuginfo=pvisitor.current_lineinfo)
ind_entry, ind_exit = graph.add_map('indirection', {
'__i%d' % i: '%s:%s+1:%s' % (s, e, t)
for i, (s, e, t) in enumerate(mapped_rng)
},
debuginfo=pvisitor.current_lineinfo)
inp_base_path.insert(0, ind_entry)
out_base_path.append(ind_exit)

Expand Down Expand Up @@ -1339,9 +1337,10 @@ def defined(self):
result.update(self.sdfg.arrays)

# MPI-related stuff
result.update(
{k: self.sdfg.process_grids[v]
for k, v in self.variables.items() if v in self.sdfg.process_grids})
result.update({
k: self.sdfg.process_grids[v]
for k, v in self.variables.items() if v in self.sdfg.process_grids
})
try:
from mpi4py import MPI
result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)})
Expand Down Expand Up @@ -3218,8 +3217,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if (not is_return and isinstance(target, ast.Name) and true_name and not op
and not isinstance(true_array, data.Scalar) and not (true_array.shape == (1, ))):
if true_name in self.views:
if result in self.sdfg.arrays and self.views[true_name] == (
result, Memlet.from_array(result, self.sdfg.arrays[result])):
if result in self.sdfg.arrays and self.views[true_name] == (result,
Memlet.from_array(
result, self.sdfg.arrays[result])):
continue
else:
raise DaceSyntaxError(self, target, 'Cannot reassign View "{}"'.format(name))
Expand Down Expand Up @@ -3762,14 +3762,12 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no
from dace.frontend.python.parser import infer_symbols_from_datadescriptor

# Map internal SDFG symbols by adding keyword arguments
# symbols = set(sdfg.symbols.keys())
# symbols = sdfg.free_symbols
symbols = sdfg.used_symbols(all_symbols=False)
try:
mapping = infer_symbols_from_datadescriptor(
sdfg, {k: self.sdfg.arrays[v]
for k, v in args if v in self.sdfg.arrays},
set(sym.arg for sym in node.keywords if sym.arg in symbols))
mapping = infer_symbols_from_datadescriptor(sdfg, {
k: self.sdfg.arrays[v]
for k, v in args if v in self.sdfg.arrays
}, set(sym.arg for sym in node.keywords if sym.arg in symbols))
except ValueError as ex:
raise DaceSyntaxError(self, node, str(ex))
if len(mapping) == 0: # Default to same-symbol mapping
Expand Down Expand Up @@ -4733,7 +4731,7 @@ def visit_Dict(self, node: ast.Dict):
def visit_Lambda(self, node: ast.Lambda):
# Return a string representation of the function
return astutils.unparse(node)

def visit_TypeAlias(self, node: TypeAlias):
raise NotImplementedError('Type aliases are not supported in DaCe')

Expand Down Expand Up @@ -4922,11 +4920,12 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr):
# NOTE: We convert the subsets to string because keeping the original symbolic information causes
# equality check failures, e.g., in LoopToMap.
self.last_state.add_nedge(
rnode, wnode, Memlet(data=array,
subset=str(expr.subset),
other_subset=str(other_subset),
volume=expr.accesses,
wcr=expr.wcr))
rnode, wnode,
Memlet(data=array,
subset=str(expr.subset),
other_subset=str(other_subset),
volume=expr.accesses,
wcr=expr.wcr))
return tmp

def _parse_subscript_slice(self,
Expand Down
92 changes: 57 additions & 35 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def _define_local_structure(pv: ProgramVisitor,
""" Defines a local structure in a DaCe program. """
name = sdfg.temp_data_name()
desc = copy.deepcopy(dtype)
desc.transient=True
desc.storage=storage
desc.lifetime=lifetime
desc.transient = True
desc.storage = storage
desc.lifetime = lifetime
sdfg.add_datadesc(name, desc)
pv.variables[name] = name
return name
Expand Down Expand Up @@ -318,16 +318,20 @@ def _numpy_full(pv: ProgramVisitor,

if is_data:
state.add_mapped_tasklet(
'_numpy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)},
'_numpy_full_', {
"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)
},
dict(__inp=dace.Memlet(data=fill_value, subset='0')),
"__out = __inp",
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
else:
state.add_mapped_tasklet(
'_numpy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)}, {},
'_numpy_full_', {
"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)
}, {},
"__out = {}".format(fill_value),
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
Expand Down Expand Up @@ -447,8 +451,10 @@ def _numpy_flip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, axis
inpidx = ','.join([f'__i{i}' for i in range(ndim)])
outidx = ','.join([f'{s} - __i{i} - 1' if a else f'__i{i}' for i, (a, s) in enumerate(zip(axis, desc.shape))])
state.add_mapped_tasklet(name="_numpy_flip_",
map_ranges={f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)},
map_ranges={
f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)
},
inputs={'__inp': Memlet(f'{arr}[{inpidx}]')},
code='__out = __inp',
outputs={'__out': Memlet(f'{arr_copy}[{outidx}]')},
Expand Down Expand Up @@ -518,8 +524,10 @@ def _numpy_rot90(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, k=1

outidx = ','.join(out_indices)
state.add_mapped_tasklet(name="_rot90_",
map_ranges={f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)},
map_ranges={
f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)
},
inputs={'__inp': Memlet(f'{arr}[{inpidx}]')},
code='__out = __inp',
outputs={'__out': Memlet(f'{arr_copy}[{outidx}]')},
Expand Down Expand Up @@ -623,8 +631,10 @@ def _elementwise(pv: 'ProgramVisitor',
else:
state.add_mapped_tasklet(
name="_elementwise_",
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)},
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)
},
inputs={'__inp': Memlet.simple(in_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
code=code,
outputs={'__out': Memlet.simple(out_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
Expand Down Expand Up @@ -674,8 +684,10 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype:
else:
state.add_mapped_tasklet(
name=func,
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)},
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)
},
inputs={'__inp': Memlet.simple(inpname, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
code='__out = {f}(__inp)'.format(f=func),
outputs={'__out': Memlet.simple(outname, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
Expand Down Expand Up @@ -1024,22 +1036,27 @@ def _argminmax(pv: ProgramVisitor,
code = "__init = _val_and_idx(val={}, idx=-1)".format(
dtypes.min_value(a_arr.dtype) if func == 'max' else dtypes.max_value(a_arr.dtype))

nest.add_state().add_mapped_tasklet(
name="_arg{}_convert_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
inputs={},
code=code,
outputs={
'__init': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
},
external_edges=True)
nest.add_state().add_mapped_tasklet(name="_arg{}_convert_".format(func),
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis
},
inputs={},
code=code,
outputs={
'__init':
Memlet.simple(
reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
},
external_edges=True)

nest.add_state().add_mapped_tasklet(
name="_arg{}_reduce_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape)},
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape)
},
inputs={'__in': Memlet.simple(a, ','.join('__i%d' % i for i in range(len(a_arr.shape))))},
code="__out = _val_and_idx(idx={}, val=__in)".format("__i%d" % axis),
outputs={
Expand All @@ -1059,8 +1076,10 @@ def _argminmax(pv: ProgramVisitor,

nest.add_state().add_mapped_tasklet(
name="_arg{}_extract_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis
},
inputs={
'__in': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
Expand Down Expand Up @@ -1183,9 +1202,10 @@ def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str):
opcode = 'not'

name, _ = sdfg.add_temp_transient(arr1.shape, restype, arr1.storage)
state.add_mapped_tasklet("_%s_" % opname, {'__i%d' % i: '0:%s' % s
for i, s in enumerate(arr1.shape)},
{'__in1': Memlet.simple(op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))},
state.add_mapped_tasklet("_%s_" % opname, {
'__i%d' % i: '0:%s' % s
for i, s in enumerate(arr1.shape)
}, {'__in1': Memlet.simple(op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))},
'__out = %s __in1' % opcode,
{'__out': Memlet.simple(name, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))},
external_edges=True)
Expand Down Expand Up @@ -4709,8 +4729,10 @@ def _cupy_full(pv: ProgramVisitor,
name, _ = sdfg.add_temp_transient(shape, dtype, storage=dtypes.StorageType.GPU_Global)

state.add_mapped_tasklet(
'_cupy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)}, {},
'_cupy_full_', {
"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)
}, {},
"__out = {}".format(fill_value),
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
Expand Down
Loading

0 comments on commit 15fb33c

Please sign in to comment.