diff --git a/core/src/ops/matmul/optimized.rs b/core/src/ops/matmul/optimized.rs
index 2f6bbd2ab0..5c7c74380b 100644
--- a/core/src/ops/matmul/optimized.rs
+++ b/core/src/ops/matmul/optimized.rs
@@ -12,7 +12,7 @@ use tract_linalg::mmm::panel_extract::PanelExtractInput;
 use tract_linalg::mmm::{
     AsInputValue, EagerPackedInput, FusedSpec, MMMInputValue, MatMatMul, OutputStoreSpec,
 };
-use tract_linalg::{Scaler, BinOp};
+use tract_linalg::{BinOp, Scaler};
 use tract_smallvec::ToSmallVec;
 
 #[derive(Clone, Debug)]
@@ -565,7 +565,12 @@ impl TypedOp for OptMatMul {
             }
         }
         if let Some(op) = succ.op_as::<TypedBinOp>() {
-            if op.binop.is::<Max>() && self.mmm.len() == 1 {
+            let in_1_fact = model.outlet_fact(succ.inputs[0])?;
+            let in_2_fact = model.outlet_fact(succ.inputs[1])?;
+            if op.binop.is::<Max>()
+                && self.mmm.len() == 1
+                && in_1_fact.without_value() == in_2_fact.without_value()
+            {
                 let other_slot = 1 - node.outputs[0].successors[0].slot;
                 let other_input = succ.inputs[other_slot];
                 let other_input = patch.tap_model(model, other_input)?;
@@ -583,6 +588,15 @@ impl TypedOp for OptMatMul {
                         &[other_input],
                     );
                 }
+            } else {
+                let mut binop =
+                    if let Some(op) = op.binop.as_linalg_binop() { op } else { return Ok(None) };
+                let flipped = succ.inputs[0].node == node.id;
+                if flipped {
+                    binop = binop.flip();
+                }
+                let other_outlet = succ.inputs[flipped as usize];
+                return self.fuse_binary(model, node, patch, other_outlet, binop);
             }
         };
         Ok(None)
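---

Reviewer note on the new fallback branch: when the successor is not the Max special case, the patch normalizes operand order before calling `fuse_binary`, flipping the linalg-level op whenever the matmul output feeds slot 0 of the binary. Below is a minimal, self-contained sketch of that flipping idea; the `BinOp` enum and its variant names are local stand-ins chosen for illustration, not tract_linalg's actual type.

```rust
/// Illustrative stand-in for a linalg-level binary-op selector.
/// (Variant names are assumptions, not tract_linalg's API.)
#[derive(Clone, Copy, Debug, PartialEq)]
enum BinOp {
    Add,
    Mul,
    Sub,  // acc - other
    SubF, // other - acc: "flipped" subtraction
}

impl BinOp {
    /// Commutative ops are unchanged; non-commutative ops swap operand roles.
    fn flip(self) -> BinOp {
        match self {
            BinOp::Sub => BinOp::SubF,
            BinOp::SubF => BinOp::Sub,
            other => other,
        }
    }
}

fn main() {
    // Mirrors the patch: if the matmul feeds slot 0 of the binary successor,
    // flip the op so the fused kernel always sees the accumulator on the same
    // side, then fuse the other input (slot 1) as the extra operand.
    let matmul_is_input_0 = true;
    let mut binop = BinOp::Sub;
    if matmul_is_input_0 {
        binop = binop.flip();
    }
    assert_eq!(binop, BinOp::SubF);
    // Commutative case: flipping is a no-op.
    assert_eq!(BinOp::Mul.flip(), BinOp::Mul);
}
```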