diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index a0cb08ea0c..11026f03d7 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -631,6 +631,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: integers_only = self.integers_only to_promote = find_promotable_scalars(sdfg, transients_only=transients_only, integers_only=integers_only) + promoted = set() if ignore: to_promote -= ignore if len(to_promote) == 0: @@ -640,7 +641,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] # Step 2: Assignment tasklets for node in scalar_nodes: - if state.in_degree(node) == 0: + if node.data in promoted or state.in_degree(node) == 0: continue in_edge = state.in_edges(node)[0] input = in_edge.src @@ -681,6 +682,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # Clean up all nodes after assignment was transferred new_state.remove_nodes_from(new_state.nodes()) + promoted.add(node.data) # Step 3: Scalar reads remove_scalar_reads(sdfg, {k: k for k in to_promote})