Skip to content

Commit

Permalink
Merge pull request #369 from robertknight/chained-opts-fusion
Browse files Browse the repository at this point in the history
Apply fusions correctly when fused ops are chained
  • Loading branch information
robertknight authored Sep 21, 2024
2 parents e276e29 + 100897a commit 3463b62
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions src/optimize.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Display, Formatter};

Expand Down Expand Up @@ -142,8 +143,22 @@ impl GraphMutator {
}
}

for fusion in fusions {
fusion.apply(self);
// Map of old_output_id => new_output_id for subgraphs that have been
// replaced by fusions.
let mut replaced_ids = HashMap::<NodeId, NodeId>::new();

for mut fusion in fusions {
// Replace input IDs which match output IDs of previously applied
// fusions.
for input_id in fusion.input_ids.iter_mut().flatten() {
if let Some(replacement_id) = replaced_ids.get(input_id) {
*input_id = *replacement_id;
}
}

let (old_output_id, new_output_id) = fusion.apply(self);

replaced_ids.insert(old_output_id, new_output_id);
}
}

Expand Down Expand Up @@ -230,7 +245,9 @@ impl Fusion {
///
/// This adds the fused operator to the graph and replaces references to
/// the original output nodes with the fused operator's outputs.
fn apply(self, graph: &mut GraphMutator) {
///
/// Returns a tuple of `(old_output_id, new_output_id)`.
fn apply(self, graph: &mut GraphMutator) -> (NodeId, NodeId) {
let Fusion {
name,
fused_op,
Expand All @@ -240,6 +257,7 @@ impl Fusion {

let fused_op_output_id = graph.add_operator(name.as_deref(), fused_op, &input_ids);
graph.replace_value(old_output_id, fused_op_output_id);
(old_output_id, fused_op_output_id)
}
}

Expand Down Expand Up @@ -746,6 +764,31 @@ mod tests {
assert_eq!(op.name(), Some("mul"));
}

#[test]
fn test_chained_fused_ops() {
let mut graph = Graph::new();

// Add two consecutive decomposed Silu operations
let input = graph.add_value(None, None);
let (_, sigmoid_out) = graph.add_simple_op("sigmoid", Sigmoid {}, &[input]);
let (_, mul_out) = graph.add_simple_op("mul", Mul {}, &[input, sigmoid_out]);
let (_, sigmoid_2_out) = graph.add_simple_op("sigmoid", Sigmoid {}, &[mul_out]);
let (_, mul_2_out) = graph.add_simple_op("mul", Mul {}, &[mul_out, sigmoid_2_out]);
graph.set_input_ids(&[input]);
graph.set_output_ids(&[mul_2_out]);

let graph = optimize_graph(graph).unwrap();

// Check that both ops were fused. This requires that the inputs to the
// second group of fused nodes are updated after fusing the first.
let (_, fused_op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
assert_eq!(fused_op.operator().name(), "Silu");
let (_, fused_op_2) = graph
.get_source_node(fused_op.input_ids()[0].unwrap())
.unwrap();
assert_eq!(fused_op_2.operator().name(), "Silu");
}

#[test]
fn test_fuse_gelu() {
let mut graph = Graph::new();
Expand Down

0 comments on commit 3463b62

Please sign in to comment.