Skip to content

Commit

Permalink
Apply fusions correctly when fused ops are chained
Browse files Browse the repository at this point in the history
Given a graph structure like:

```
A -> B -> C -> D
```

When fusions are generated for `A -> B` and `C -> D`, the input for "C" needs
to be replaced by the output of the fused `A -> B`.
  • Loading branch information
robertknight committed Sep 20, 2024
1 parent e276e29 commit 23cd988
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions src/optimize.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Display, Formatter};

Expand Down Expand Up @@ -142,8 +143,24 @@ impl GraphMutator {
}
}

for fusion in fusions {
fusion.apply(self);
// Map of old_output_id => new_output_id for subgraphs that have been
// replaced by fusions.
let mut replaced_ids = HashMap::<NodeId, NodeId>::new();

for mut fusion in fusions {
// Replace input IDs which match output IDs of previously applied
// fusions.
for input_id in fusion.input_ids.iter_mut() {
if let Some(input_id) = input_id {
if let Some(replacement_id) = replaced_ids.get(input_id) {
*input_id = *replacement_id;
}
}
}

let (old_output_id, new_output_id) = fusion.apply(self);

replaced_ids.insert(old_output_id, new_output_id);
}
}

Expand Down Expand Up @@ -230,7 +247,9 @@ impl Fusion {
///
/// This adds the fused operator to the graph and replaces references to
/// the original output nodes with the fused operator's outputs.
fn apply(self, graph: &mut GraphMutator) {
///
/// Returns a tuple of `(old_output_id, new_output_id)`.
fn apply(self, graph: &mut GraphMutator) -> (NodeId, NodeId) {
let Fusion {
name,
fused_op,
Expand All @@ -240,6 +259,7 @@ impl Fusion {

let fused_op_output_id = graph.add_operator(name.as_deref(), fused_op, &input_ids);
graph.replace_value(old_output_id, fused_op_output_id);
(old_output_id, fused_op_output_id)
}
}

Expand Down

0 comments on commit 23cd988

Please sign in to comment.