Skip to content

Commit

Permalink
Better unicast fusing
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Oct 18, 2024
1 parent 097fe74 commit 3716bde
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions core/src/ops/matmul/optimized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tract_linalg::mmm::panel_extract::PanelExtractInput;
use tract_linalg::mmm::{
AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
};
use tract_linalg::{Scaler, BinOp};
use tract_linalg::{BinOp, Scaler};
use tract_smallvec::ToSmallVec;

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -565,7 +565,12 @@ impl TypedOp for OptMatMul {
}
}
if let Some(op) = succ.op_as::<ops::binary::OptBinUnicast>() {
if op.binop.is::<ops::math::Add>() && self.mmm.len() == 1 {
let in_1_fact = model.outlet_fact(succ.inputs[0])?;
let in_2_fact = model.outlet_fact(succ.inputs[1])?;
if op.binop.is::<ops::math::Add>()
&& self.mmm.len() == 1
&& in_1_fact.without_value() == in_2_fact.without_value()
{
let other_slot = 1 - node.outputs[0].successors[0].slot;
let other_input = succ.inputs[other_slot];
let other_input = patch.tap_model(model, other_input)?;
Expand All @@ -583,6 +588,15 @@ impl TypedOp for OptMatMul {
&[other_input],
);
}
} else {
let mut binop =
if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) };
let flipped = succ.inputs[0].node == node.id;
if flipped {
binop = binop.flip();
}
let other_outlet = succ.inputs[flipped as usize];
return self.fuse_binary(model, node, patch, other_outlet, binop);
}
};
Ok(None)
Expand Down

0 comments on commit 3716bde

Please sign in to comment.