diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 28431deeea..39335dd90d 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -581,6 +581,10 @@ def from_json(json_obj, context=None): return ret def used_symbols(self, all_symbols: bool) -> Set[str]: + free_syms = set().union(*(map(str, pystr_to_symbolic(v).free_symbols) for v in self.location.values())) + + keys_to_use = set(self.symbol_mapping.keys()) + free_syms = set().union(*(map(str, pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()), *(map(str, @@ -589,8 +593,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Filter out unused internal symbols from symbol mapping if not all_symbols: internally_used_symbols = self.sdfg.used_symbols(all_symbols=False) - free_syms &= internally_used_symbols - + keys_to_use &= internally_used_symbols + + free_syms |= set().union(*(map(str, + pystr_to_symbolic(v).free_symbols) for k, v in self.symbol_mapping.items() + if k in keys_to_use)) + return free_syms @property @@ -640,7 +648,7 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname) if dname in connectors and desc.transient: raise NameError('"%s" is a connector but its corresponding array is transient' % dname) - + # Validate inout connectors from dace.sdfg import utils # Avoids circular import inout_connectors = self.in_connectors.keys() & self.out_connectors.keys()