diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 238d0b72c7..8388cce250 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1341,6 +1341,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: defined_syms |= set(self.constants_prop.keys()) + # Add used symbols from init and exit code + for code in self.init_code.values(): + free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + for code in self.exit_code.values(): + free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + # Add free state symbols used_before_assignment = set() @@ -1472,7 +1478,7 @@ def init_signature(self, for_call=False, free_symbols=None) -> str: :param for_call: If True, returns arguments that can be used when calling the SDFG. """ # Get global free symbols scalar arguments - free_symbols = free_symbols or self.free_symbols + free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False) return ", ".join( dt.Scalar(self.symbols[k]).as_arg(name=k, with_types=not for_call, for_call=for_call) for k in sorted(free_symbols) if not k.startswith('__dace'))