From 23cd988498504e3e493189bf8ae99d210083e4af Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 20 Sep 2024 21:45:07 +0100 Subject: [PATCH] Apply fusions correctly when fused ops are chained Given a graph structure like: ``` A -> B -> C -> D ``` When fusions are generated for `A -> B` and `C -> D`, the input for "C" needs to be replaced by the output of the fused `A -> B`. --- src/optimize.rs | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/optimize.rs b/src/optimize.rs index 1fdee464..5290c8c9 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::error::Error; use std::fmt::{Display, Formatter}; @@ -142,8 +143,24 @@ impl GraphMutator { } } - for fusion in fusions { - fusion.apply(self); + // Map of old_output_id => new_output_id for subgraphs that have been + // replaced by fusions. + let mut replaced_ids = HashMap::::new(); + + for mut fusion in fusions { + // Replace input IDs which match output IDs of previously applied + // fusions. + for input_id in fusion.input_ids.iter_mut() { + if let Some(input_id) = input_id { + if let Some(replacement_id) = replaced_ids.get(input_id) { + *input_id = *replacement_id; + } + } + } + + let (old_output_id, new_output_id) = fusion.apply(self); + + replaced_ids.insert(old_output_id, new_output_id); } } @@ -230,7 +247,9 @@ impl Fusion { /// /// This adds the fused operator to the graph and replaces references to /// the original output nodes with the fused operator's outputs. - fn apply(self, graph: &mut GraphMutator) { + /// + /// Returns a tuple of `(old_output_id, new_output_id)`. + fn apply(self, graph: &mut GraphMutator) -> (NodeId, NodeId) { let Fusion { name, fused_op, @@ -240,6 +259,7 @@ impl Fusion { let fused_op_output_id = graph.add_operator(name.as_deref(), fused_op, &input_ids); graph.replace_value(old_output_id, fused_op_output_id); + (old_output_id, fused_op_output_id) } }