Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Name Validation #1661

Merged
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
8bafffa
Added checks in the `add_{symbol, datadesc, constant, subarray, rdist…
philip-paul-mueller Sep 18, 2024
3ce8454
The nested SDFG now checks if the input connector is a symbol.
philip-paul-mueller Sep 18, 2024
b1d41ea
Updated the SDFG validation.
philip-paul-mueller Sep 18, 2024
f82ea7b
Forget the constants.
philip-paul-mueller Sep 18, 2024
5d2b3b6
It is allowed to update constants, so we have to behave differently.
philip-paul-mueller Sep 18, 2024
7e75894
If a constant is created and an entity with the same name already exi…
philip-paul-mueller Sep 18, 2024
52ee269
Some general improvements.
philip-paul-mueller Sep 19, 2024
55830cc
Updated the `add_constant()` function.
philip-paul-mueller Sep 19, 2024
2204b16
The past is too strong.
philip-paul-mueller Sep 19, 2024
f9d446c
Fixed some naming issue.
philip-paul-mueller Sep 19, 2024
eea5e3a
Why is this so hard?
philip-paul-mueller Sep 19, 2024
4cb2211
Let's hope that this does the trick.
philip-paul-mueller Sep 19, 2024
53d31b7
New try to prevent multi name clashes.
philip-paul-mueller Sep 20, 2024
a962e05
Sometimes.
philip-paul-mueller Sep 20, 2024
f289f85
The `find_new_name` was there twice, probably bad merge.
philip-paul-mueller Sep 20, 2024
c89cd32
Found the offending code.
philip-paul-mueller Sep 20, 2024
a2c8c5b
Merge branch 'master' into better_symbol_checking
philip-paul-mueller Sep 20, 2024
da11849
The `_redistribute()` function still created pseudo scalar with the s…
philip-paul-mueller Sep 23, 2024
fc27a84
Patched the Python interpreter to properly handle process grid and such.
philip-paul-mueller Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,27 +136,6 @@ def create_datadescriptor(obj, no_custom_desc=False):
'adaptor method to the type hint or object itself.')


def find_new_name(name: str, existing_names: Sequence[str]) -> str:
    """
    Derives a variant of ``name`` that is absent from ``existing_names``.

    If ``name`` itself is unused it is returned unchanged. Otherwise an
    underscore and the smallest non-negative counter that produces an unused
    name are appended (``name_0``, ``name_1``, ...).

    :param name: The desired base name.
    :param existing_names: The collection of names that are already taken.
    :return: A name that is not contained in ``existing_names``.
    """
    if name not in existing_names:
        return name
    counter = 0
    candidate = f'{name}_{counter}'
    while candidate in existing_names:
        counter += 1
        candidate = f'{name}_{counter}'
    return candidate


def _prod(sequence):
return functools.reduce(lambda a, b: a * b, sequence, 1)

Expand Down
26 changes: 13 additions & 13 deletions dace/frontend/common/distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: Shape
state.add_node(tasklet)

# Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
_, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
wnode = state.add_write(pgrid_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(pgrid_name, scal))
scal_name, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True, find_new_name=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

return pgrid_name

Expand Down Expand Up @@ -97,9 +97,9 @@ def _cart_sub(pv: 'ProgramVisitor',
state.add_node(tasklet)

# Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
_, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
wnode = state.add_write(pgrid_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(pgrid_name, scal))
scal_name, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True, find_new_name=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

return pgrid_name

Expand Down Expand Up @@ -196,7 +196,7 @@ def _intracomm_bcast(pv: 'ProgramVisitor',
if comm_obj == MPI.COMM_WORLD:
return _bcast(pv, sdfg, state, buffer, root)
# NOTE: Highly experimental
sdfg.add_scalar(comm_name, dace.int32)
scal_name, _ = sdfg.add_scalar(comm_name, dace.int32, find_new_name=True)
return _bcast(pv, sdfg, state, buffer, root, fcomm=comm_name)


Expand Down Expand Up @@ -941,9 +941,9 @@ def _subarray(pv: ProgramVisitor,
state.add_node(tasklet)

# Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
_, scal = sdfg.add_scalar(subarray_name, dace.int32, transient=True)
wnode = state.add_write(subarray_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(subarray_name, scal))
scal_name, scal = sdfg.add_scalar(subarray_name, dace.int32, transient=True, find_new_name=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

return subarray_name

Expand Down Expand Up @@ -1078,9 +1078,9 @@ def _redistribute(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, in_buffer: s
f'int* {rdistrarray_name}_self_size;'
])
state.add_node(tasklet)
_, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True)
wnode = state.add_write(rdistrarray_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(rdistrarray_name, scal))
scal_name, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

libnode = Redistribute('_Redistribute_', rdistrarray_name)

Expand Down
9 changes: 6 additions & 3 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
internally_used_symbols = self.sdfg.used_symbols(all_symbols=False)
keys_to_use &= internally_used_symbols

# Translate the internal symbols back to their external counterparts.
free_syms |= set().union(*(map(str,
pystr_to_symbolic(v).free_symbols) for k, v in self.symbol_mapping.items()
if k in keys_to_use))
Expand Down Expand Up @@ -662,6 +663,10 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
Expand Down Expand Up @@ -795,10 +800,8 @@ def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
for p, rng in zip(self._map.params, self._map.range):
result[p] = dtypes.result_type_of(infer_expr_type(rng[0], symbols), infer_expr_type(rng[1], symbols))

# Add dynamic inputs
# Handle the dynamic map ranges.
dyn_inputs = set(c for c in self.in_connectors if not c.startswith('IN_'))

# Try to get connector type from connector
for e in state.in_edges(self):
if e.dst_conn in dyn_inputs:
result[e.dst_conn] = (self.in_connectors[e.dst_conn] or sdfg.arrays[e.data.data].dtype)
Expand Down
159 changes: 106 additions & 53 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,17 +746,32 @@ def replace_dict(self,

super().replace_dict(repldict, symrepl, replace_in_graph, replace_keys)

def add_symbol(self, name, stype):
def add_symbol(self, name, stype, find_new_name: bool = False):
""" Adds a symbol to the SDFG.

:param name: Symbol name.
:param stype: Symbol type.
:param find_new_name: Find a new name.
"""
if name in self.symbols:
raise FileExistsError('Symbol "%s" already exists in SDFG' % name)
if find_new_name:
name = self._find_new_name(name)
else:
# We do not check for data constants, because there is a link between the constants and
# the data descriptors.
if name in self.symbols:
raise FileExistsError(f'Symbol "{name}" already exists in SDFG')
if name in self.arrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a data descriptor.')
if name in self._subarrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a subarray.')
if name in self._rdistrarrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a RedistrArray.')
if name in self._pgrids:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a ProcessGrid.')
if not isinstance(stype, dtypes.typeclass):
stype = dtypes.dtype_to_typeclass(stype)
self.symbols[name] = stype
return name

def remove_symbol(self, name):
""" Removes a symbol from the SDFG.
Expand Down Expand Up @@ -1159,14 +1174,23 @@ def cast(dtype: dt.Data, value: Any):
return result

def add_constant(self, name: str, value: Any, dtype: dt.Data = None):
""" Adds/updates a new compile-time constant to this SDFG. A constant
may either be a scalar or a numpy ndarray thereof.
"""
Adds/updates a new compile-time constant to this SDFG.

:param name: The name of the constant.
:param value: The constant value.
:param dtype: Optional data type of the symbol, or None to deduce
automatically.
A constant may either be a scalar or a numpy ndarray thereof. It is not an
error if there is already a symbol or an array with the same name inside
the SDFG. However, the data descriptors must refer to the same type.

:param name: The name of the constant.
:param value: The constant value.
:param dtype: Optional data type of the symbol, or None to deduce automatically.
"""
if name in self._subarrays:
raise FileExistsError(f'Can not create constant "{name}", the name is used by a subarray.')
if name in self._rdistrarrays:
raise FileExistsError(f'Can not create constant "{name}", the name is used by a RedistrArray.')
if name in self._pgrids:
raise FileExistsError(f'Can not create constant "{name}", the name is used by a ProcessGrid.')
self.constants_prop[name] = (dtype or dt.create_datadescriptor(value), value)

@property
Expand Down Expand Up @@ -1598,36 +1622,44 @@ def _find_new_name(self, name: str):
""" Tries to find a new name by adding an underscore and a number. """

names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | self._subarrays.keys()
| self._rdistrarrays.keys())
| self._rdistrarrays.keys() | self.symbols.keys())
return dt.find_new_name(name, names)

def is_name_used(self, name: str) -> bool:
    """ Checks if `name` is already used inside the SDFG."""
    # A name counts as used as soon as any of the SDFG's registries knows it:
    # data descriptors, symbols, constants, process grids, subarrays or
    # redistribution arrays.
    registries = (
        self._arrays,
        self.symbols,
        self.constants_prop,
        self._pgrids,
        self._subarrays,
        self._rdistrarrays,
    )
    return any(name in registry for registry in registries)

def is_name_free(self, name: str) -> bool:
    """ Test if `name` is free, i.e. is not used by anything else."""
    # A name is free exactly when no registry of the SDFG claims it.
    if self.is_name_used(name):
        return False
    return True

def find_new_constant(self, name: str):
"""
Tries to find a new constant name by adding an underscore and a number.
Tries to find a new name for a constant.
"""
constants = self.constants
if name not in constants:
if self.is_name_free(name):
return name

index = 0
while (name + ('_%d' % index)) in constants:
index += 1

return name + ('_%d' % index)
return self._find_new_name(name)

def find_new_symbol(self, name: str):
"""
Tries to find a new symbol name by adding an underscore and a number.
"""
symbols = self.symbols
if name not in symbols:
if self.is_name_free(name):
return name

index = 0
while (name + ('_%d' % index)) in symbols:
index += 1

return name + ('_%d' % index)
return self._find_new_name(name)

def add_array(self,
name: str,
Expand Down Expand Up @@ -1856,13 +1888,14 @@ def add_transient(self,

def temp_data_name(self):
""" Returns a temporary data descriptor name that can be used in this SDFG. """

name = '__tmp%d' % self._temp_transients
while name in self._arrays:

# NOTE: Consider switching to `_find_new_name`
# The frontend seems to access this variable directly.
while self.is_name_used(name):
self._temp_transients += 1
name = '__tmp%d' % self._temp_transients
self._temp_transients += 1

return name

def add_temp_transient(self,
Expand Down Expand Up @@ -1917,29 +1950,47 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str
"""
if not isinstance(name, str):
raise TypeError("Data descriptor name must be a string. Got %s" % type(name).__name__)
# If exists, fail
while name in self._arrays:
if find_new_name:
name = self._find_new_name(name)
else:
raise NameError(f'Array or Stream with name "{name}" already exists in SDFG')
# NOTE: Remove illegal characters, such as dots. Such characters may be introduced when creating views to
# members of Structures.
name = name.replace('.', '_')
assert name not in self._arrays
self._arrays[name] = datadesc

def _add_symbols(desc: dt.Data):
if find_new_name:
# These characters might be introduced through the creation of views to members
# of structures.
# NOTES: If `find_new_name` is `True` and the name (understood as a sequence of
# any characters) is not used, i.e. `assert self.is_name_free(name)`, then it
# is still "cleaned", i.e. dots are replaced with underscores. However, if
# `find_new_name` is `False` then this cleaning is not applied and it is possible
# to create names that are formally invalid. The above code reproduces the exact
# same behaviour and is maintained for compatibility. This behaviour is
# triggered by tests/python_frontend/structures/structure_python_test.py::test_rgf`.
name = self._find_new_name(name)
name = name.replace('.', '_')
if self.is_name_used(name):
name = self._find_new_name(name)
else:
# We do not check for data constants, because there is a link between the constants and
# the data descriptors.
if name in self.arrays:
raise FileExistsError(f'Data descriptor "{name}" already exists in SDFG')
if name in self.symbols:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a symbol.')
if name in self._subarrays:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a subarray.')
if name in self._rdistrarrays:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a RedistrArray.')
if name in self._pgrids:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a ProcessGrid.')

def _add_symbols(sdfg: SDFG, desc: dt.Data):
if isinstance(desc, dt.Structure):
for v in desc.members.values():
if isinstance(v, dt.Data):
_add_symbols(v)
_add_symbols(sdfg, v)
for sym in desc.free_symbols:
if sym.name not in self.symbols:
self.add_symbol(sym.name, sym.dtype)
if sym.name not in sdfg.symbols:
sdfg.add_symbol(sym.name, sym.dtype)

# Add free symbols to the SDFG global symbol storage
_add_symbols(datadesc)
# Add the data descriptor to the SDFG and all symbols that are not yet known.
self._arrays[name] = datadesc
_add_symbols(self, datadesc)

return name

Expand Down Expand Up @@ -2044,9 +2095,10 @@ def add_subarray(self,
newshape.append(dace.symbolic.pystr_to_symbolic(s))
subshape = newshape

# No need to ensure unique test.
subarray_name = self._find_new_name('__subarray')
self._subarrays[subarray_name] = SubArray(subarray_name, dtype, shape, subshape, pgrid, correspondence)

self._subarrays[subarray_name] = SubArray(subarray_name, dtype, shape, subshape, pgrid, correspondence)
self.append_init_code(self._subarrays[subarray_name].init_code())
self.append_exit_code(self._subarrays[subarray_name].exit_code())

Expand All @@ -2060,12 +2112,13 @@ def add_rdistrarray(self, array_a: str, array_b: str):
:param array_b: Output sub-array descriptor.
:return: Name of the new redistribution descriptor.
"""
# No need to ensure unique test.
name = self._find_new_name('__rdistrarray')

rdistrarray_name = self._find_new_name('__rdistrarray')
self._rdistrarrays[rdistrarray_name] = RedistrArray(rdistrarray_name, array_a, array_b)
self.append_init_code(self._rdistrarrays[rdistrarray_name].init_code(self))
self.append_exit_code(self._rdistrarrays[rdistrarray_name].exit_code(self))
return rdistrarray_name
self._rdistrarrays[name] = RedistrArray(name, array_a, array_b)
self.append_init_code(self._rdistrarrays[name].init_code(self))
self.append_exit_code(self._rdistrarrays[name].exit_code(self))
return name

def add_loop(
self,
Expand Down
28 changes: 28 additions & 0 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,34 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
if len(blocks) != len(set([s.label for s in blocks])):
raise InvalidSDFGError('Found multiple blocks with the same name in ' + cfg.name, sdfg, None)

# Check the names of data descriptors and co.
seen_names: Set[str] = set()
for obj_names in [
sdfg.arrays.keys(), sdfg.symbols.keys(), sdfg._rdistrarrays.keys(), sdfg._subarrays.keys()
]:
if not seen_names.isdisjoint(obj_names):
raise InvalidSDFGError(
f'Found duplicated names: "{seen_names.intersection(obj_names)}". Please ensure '
'that the names of symbols, data descriptors, subarrays and rdistarrays are unique.', sdfg, None)
seen_names.update(obj_names)

# Ensure that there is a mentioning of constants in either the array or symbol.
for const_name, (const_type, _) in sdfg.constants_prop.items():
if const_name in sdfg.arrays:
if const_type != sdfg.arrays[const_name].dtype:
# This should actually be an error, but there is a lot of code that depends on it.
warnings.warn(
f'Mismatch between constant and data descriptor of "{const_name}", '
f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".')
elif const_name in sdfg.symbols:
if const_type != sdfg.symbols[const_name]:
# This should actually be an error, but there is a lot of code that depends on it.
warnings.warn(
f'Mismatch between constant and symobl type of "{const_name}", '
f'expected to find "{const_type}" but found "{sdfg.symbols[const_name]}".')
else:
warnings.warn(f'Found constant "{const_name}" that does not refer to an array or a symbol.')

# Validate data descriptors
for name, desc in sdfg._arrays.items():
if id(desc) in references:
Expand Down
Loading