diff --git a/src/optimize.rs b/src/optimize.rs index 1fdee464..b0d8dbf4 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::error::Error; use std::fmt::{Display, Formatter}; @@ -142,8 +143,22 @@ impl GraphMutator { } } - for fusion in fusions { - fusion.apply(self); + // Map of old_output_id => new_output_id for subgraphs that have been + // replaced by fusions. + let mut replaced_ids = HashMap::::new(); + + for mut fusion in fusions { + // Replace input IDs which match output IDs of previously applied + // fusions. + for input_id in fusion.input_ids.iter_mut().flatten() { + if let Some(replacement_id) = replaced_ids.get(input_id) { + *input_id = *replacement_id; + } + } + + let (old_output_id, new_output_id) = fusion.apply(self); + + replaced_ids.insert(old_output_id, new_output_id); } } @@ -230,7 +245,9 @@ impl Fusion { /// /// This adds the fused operator to the graph and replaces references to /// the original output nodes with the fused operator's outputs. - fn apply(self, graph: &mut GraphMutator) { + /// + /// Returns a tuple of `(old_output_id, new_output_id)`. + fn apply(self, graph: &mut GraphMutator) -> (NodeId, NodeId) { let Fusion { name, fused_op, @@ -240,6 +257,7 @@ impl Fusion { let fused_op_output_id = graph.add_operator(name.as_deref(), fused_op, &input_ids); graph.replace_value(old_output_id, fused_op_output_id); + (old_output_id, fused_op_output_id) } } @@ -746,6 +764,31 @@ mod tests { assert_eq!(op.name(), Some("mul")); } + #[test] + fn test_chained_fused_ops() { + let mut graph = Graph::new(); + + // Add two consecutive decomposed Silu operations + let input = graph.add_value(None, None); + let (_, sigmoid_out) = graph.add_simple_op("sigmoid", Sigmoid {}, &[input]); + let (_, mul_out) = graph.add_simple_op("mul", Mul {}, &[input, sigmoid_out]); + let (_, sigmoid_2_out) = graph.add_simple_op("sigmoid", Sigmoid {}, &[mul_out]); + let (_, mul_2_out) = graph.add_simple_op("mul", Mul {}, &[mul_out, sigmoid_2_out]); + graph.set_input_ids(&[input]); + graph.set_output_ids(&[mul_2_out]); + + let graph = optimize_graph(graph).unwrap(); + + // Check that both ops were fused. This requires that the inputs to the + // second group of fused nodes are updated after fusing the first. + let (_, fused_op) = graph.get_source_node(graph.output_ids()[0]).unwrap(); + assert_eq!(fused_op.operator().name(), "Silu"); + let (_, fused_op_2) = graph + .get_source_node(fused_op.input_ids()[0].unwrap()) + .unwrap(); + assert_eq!(fused_op_2.operator().name(), "Silu"); + } + #[test] fn test_fuse_gelu() { let mut graph = Graph::new();