diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py
index 64674bc3..bf865539 100644
--- a/rten-convert/rten_convert/schema_generated.py
+++ b/rten-convert/rten_convert/schema_generated.py
@@ -115,6 +115,7 @@ class OperatorType(object):
     DequantizeLinear = 105
     QuantizeLinear = 106
     DynamicQuantizeLinear = 107
+    MatMulInteger = 108
 
 
 class RNNDirection(object):
diff --git a/src/model_builder.rs b/src/model_builder.rs
index 5d4d2b37..0b45d5a6 100644
--- a/src/model_builder.rs
+++ b/src/model_builder.rs
@@ -77,6 +77,7 @@ pub enum OpType<'a> {
     Log,
     LogSoftmax(LogSoftmax),
     MatMul,
+    MatMulInteger,
     Max,
     MaxPool(MaxPool),
     Mean,
@@ -614,6 +615,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
                 }
             ),
             OpType::MatMul => op!(MatMul),
+            OpType::MatMulInteger => op!(MatMulInteger),
             OpType::Max => op!(Max),
             OpType::MaxPool(args) => op_with_attrs!(MaxPool, MaxPoolAttrs, {
                 let pad_args = pad_args_from_padding(args.padding);
diff --git a/src/op_registry.rs b/src/op_registry.rs
index e5682617..ea12cc28 100644
--- a/src/op_registry.rs
+++ b/src/op_registry.rs
@@ -133,6 +133,7 @@ impl OpRegistry {
         register_op!(LogSoftmax);
         register_op!(LSTM);
         register_op!(MatMul);
+        register_op!(MatMulInteger);
         register_op!(Max);
         register_op!(MaxPool);
         register_op!(Mean);
@@ -610,6 +611,7 @@ impl_read_op!(LSTM, attrs_as_lstmattrs, |attrs: sg::LSTMAttrs| {
     })
 });
 impl_read_op!(MatMul);
+impl_read_op!(MatMulInteger);
 impl_read_op!(Max);
 impl_read_op!(
     MaxPool,
diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs
index 1a521270..e0af4b7a 100644
--- a/src/ops/matmul.rs
+++ b/src/ops/matmul.rs
@@ -5,6 +5,7 @@ use rten_tensor::{Tensor, TensorView};
 
 use crate::check_dims;
 use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT};
+use crate::iter_util::range_chunks;
 use crate::ops::binary_elementwise::broadcast_shapes;
 use crate::ops::layout::expand_to;
 use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
@@ -266,6 +267,150 @@ impl Operator for MatMul {
     }
 }
 
+/// Extract the scalar value of an optional zero point as an `i32`.
+///
+/// Returns zero when the zero point is omitted. A zero point that is provided
+/// but does not hold a single value is rejected instead of being silently
+/// treated as zero.
+fn scalar_zero_point<T: Copy + Into<i32>>(
+    zero_point: Option<TensorView<T>>,
+) -> Result<i32, OpError> {
+    match zero_point {
+        Some(zp) => zp
+            .item()
+            .copied()
+            .map(Into::into)
+            .ok_or(OpError::InvalidValue("Only scalar zero points are supported")),
+        None => Ok(0),
+    }
+}
+
+/// Multiply two batches of 8-bit integer matrices, accumulating in `i32`.
+///
+/// The zero points, which default to zero when omitted, are subtracted from
+/// the elements of `a` and `b` respectively before they are multiplied.
+/// Leading (batch) dimensions of `a` and `b` are broadcast against each other.
+///
+/// Returns an error if either input has fewer than two dimensions, if the
+/// inner dimensions do not match, if the batch dimensions cannot be broadcast
+/// or if a zero point is not a scalar.
+pub fn matmul_integer(
+    pool: &TensorPool,
+    a: TensorView<u8>,
+    b: TensorView<i8>,
+    a_zero_point: Option<TensorView<u8>>,
+    b_zero_point: Option<TensorView<i8>>,
+) -> Result<Tensor<i32>, OpError> {
+    if a.ndim() < 2 || b.ndim() < 2 {
+        return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
+    }
+
+    let a_rows = a.size(a.ndim() - 2);
+    let a_cols = a.size(a.ndim() - 1);
+
+    let b_rows = b.size(b.ndim() - 2);
+    let b_cols = b.size(b.ndim() - 1);
+
+    if a_cols != b_rows {
+        return Err(OpError::IncompatibleInputShapes(
+            "Columns of first matrix does not match rows of second matrix",
+        ));
+    }
+
+    let a_zero = scalar_zero_point(a_zero_point)?;
+    let b_zero = scalar_zero_point(b_zero_point)?;
+
+    let a_prefix = &a.shape()[..a.ndim() - 2];
+    let b_prefix = &b.shape()[..b.ndim() - 2];
+
+    let out_prefix = broadcast_shapes(a_prefix, b_prefix)
+        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?;
+    let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat();
+
+    let mut output = Tensor::<i32>::uninit_in(pool, out_shape);
+    if output.is_empty() || a_cols == 0 {
+        // The output is all zeros if it is empty or if the depth is zero
+        // (every element is then an empty sum). The loop below only writes
+        // outputs from inside its depth-block loop, so it must not be relied
+        // on to initialize the output when there is no depth to iterate over.
+        return Ok(Tensor::zeros(out_shape));
+    }
+
+    let a_broadcast_shape = [out_prefix.as_slice(), &[a_rows, a_cols]].concat();
+    let b_broadcast_shape = [out_prefix.as_slice(), &[b_rows, b_cols]].concat();
+
+    let a_broadcast = a.broadcast(a_broadcast_shape.as_slice());
+    let b_broadcast = b.broadcast(b_broadcast_shape.as_slice());
+
+    let out_row_stride = output.stride(output.ndim() - 2);
+    let out_batches = output
+        .data_mut()
+        .unwrap()
+        .chunks_mut(out_row_stride * a_rows);
+
+    a_broadcast
+        .inner_iter::<2>()
+        .zip(b_broadcast.inner_iter::<2>())
+        .zip(out_batches)
+        .par_bridge()
+        .for_each(|((a_mat, b_mat), out_mat)| {
+            let [m, k] = a_mat.shape();
+            let [_k, n] = b_mat.shape();
+
+            // Do some extremely rudimentary cache blocking.
+            for col_block in range_chunks(0..n, 32) {
+                for depth_block in range_chunks(0..k, 32) {
+                    for row_block in range_chunks(0..m, 32) {
+                        for j in col_block.clone() {
+                            for i in row_block.clone() {
+                                let mut out = 0i32;
+                                for k in depth_block.clone() {
+                                    let a = unsafe { *a_mat.get_unchecked([i, k]) } as i32 - a_zero;
+                                    let b = unsafe { *b_mat.get_unchecked([k, j]) } as i32 - b_zero;
+                                    out += a * b;
+                                }
+                                unsafe {
+                                    let el = out_mat.get_unchecked_mut((i * out_row_stride) + j);
+                                    // The first depth block initializes the element. Later blocks
+                                    // accumulate into the value written by earlier ones.
+                                    if depth_block.start == 0 {
+                                        el.write(out);
+                                    } else {
+                                        el.write(el.assume_init() + out);
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        });
+
+    // Safety: Depth is non-zero, so the loop above initialized all output elements.
+    let output = unsafe { output.assume_init() };
+
+    Ok(output)
+}
+
+/// Operator wrapper around [`matmul_integer`].
+#[derive(Debug)]
+pub struct MatMulInteger {}
+
+impl Operator for MatMulInteger {
+    fn name(&self) -> &str {
+        "MatMulInteger"
+    }
+
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
+        let a = inputs.require_as(0)?;
+        let b = inputs.require_as(1)?;
+        // Inputs 2 and 3 are the optional zero points for `a` and `b`.
+        let a_zero_point = inputs.get_as(2)?;
+        let b_zero_point = inputs.get_as(3)?;
+        matmul_integer(pool, a, b, a_zero_point, b_zero_point).into_op_result()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::error::Error;
@@ -277,10 +422,11 @@ mod tests {
     use rten_tensor::{Tensor, TensorView, TensorViewMut};
 
     use crate::gemm::gemm;
+    use crate::ops::binary_elementwise::broadcast_shapes;
     use crate::ops::tests::new_pool;
     use crate::tensor_pool::AutoReturn;
 
-    use super::{gemm_op, matmul, matmul_impl, MatmulStrategy, OpError};
+    use super::{gemm_op, matmul, matmul_impl, matmul_integer, MatmulStrategy, OpError};
 
     fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) {
         c.make_contiguous();
@@ -323,6 +469,54 @@ mod tests {
         });
     }
 
+    /// Naive reference implementation used to check `matmul_integer` results.
+    fn reference_matmul_integer(
+        a: TensorView<u8>,
+        b: TensorView<i8>,
+        a_zero_point: Option<TensorView<u8>>,
+        b_zero_point: Option<TensorView<i8>>,
+    ) -> Tensor<i32> {
+        let a_batch_dims = a.ndim() - 2;
+        let b_batch_dims = b.ndim() - 2;
+
+        let a_prefix = &a.shape()[..a.ndim() - 2];
+        let b_prefix = &b.shape()[..b.ndim() - 2];
+        let out_prefix = broadcast_shapes(a_prefix, b_prefix).unwrap();
+        let mut out_shape = out_prefix.to_vec();
+        out_shape.push(a.size(a.ndim() - 2));
+        out_shape.push(b.size(b.ndim() - 1));
+        let mut out = Tensor::<i32>::zeros(&out_shape);
+
+        let a_bcast = [out_prefix.as_slice(), &a.shape()[a_batch_dims..]].concat();
+        let b_bcast = [out_prefix.as_slice(), &b.shape()[b_batch_dims..]].concat();
+
+        let a_zero_point = a_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;
+        let b_zero_point = b_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;
+
+        a.broadcast(a_bcast.as_slice())
+            .inner_iter::<2>()
+            .zip(b.broadcast(b_bcast.as_slice()).inner_iter::<2>())
+            .zip(out.inner_iter_mut::<2>())
+            .for_each(|((a, b), mut c)| {
+                let [n_rows, n_cols] = c.shape();
+                let depth = a.size(1);
+
+                for i in 0..n_rows {
+                    for j in 0..n_cols {
+                        let mut y = 0;
+                        for k in 0..depth {
+                            let a_el = (a[[i, k]] as i32) - a_zero_point;
+                            let b_el = (b[[k, j]] as i32) - b_zero_point;
+                            y += a_el * b_el;
+                        }
+                        c[[i, j]] = y;
+                    }
+                }
+            });
+
+        out
+    }
+
     #[test]
     fn test_gemm_op() -> Result<(), Box<dyn Error>> {
         let pool = new_pool();
@@ -573,6 +767,79 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_matmul_integer() -> Result<(), Box<dyn Error>> {
+        struct Case {
+            a: Tensor<u8>,
+            b: Tensor<i8>,
+            a_zero_point: Option<Tensor<u8>>,
+            b_zero_point: Option<Tensor<i8>>,
+        }
+
+        let cases = [
+            // No zero point
+            Case {
+                a: Tensor::from([[1, 2], [3, 4]]),
+                b: Tensor::from([[5, 6], [7, 8]]),
+                a_zero_point: None,
+                b_zero_point: None,
+            },
+            // Scalar zero points
+            Case {
+                a: Tensor::from([[1, 2], [3, 4]]),
+                b: Tensor::from([[5, 6], [7, 8]]),
+                a_zero_point: Some(Tensor::from(127)),
+                b_zero_point: Some(Tensor::from(-50)),
+            },
+            // Zero-length depth dimension
+            Case {
+                a: Tensor::zeros(&[2, 0]),
+                b: Tensor::zeros(&[0, 3]),
+                a_zero_point: None,
+                b_zero_point: None,
+            },
+        ];
+
+        let pool = new_pool();
+
+        for Case {
+            a,
+            b,
+            a_zero_point,
+            b_zero_point,
+        } in cases
+        {
+            let result = matmul_integer(
+                &pool,
+                a.view(),
+                b.view(),
+                a_zero_point.as_ref().map(|zp| zp.view()),
+                b_zero_point.as_ref().map(|zp| zp.view()),
+            )?;
+            let expected = reference_matmul_integer(
+                a.view(),
+                b.view(),
+                a_zero_point.as_ref().map(|zp| zp.view()),
+                b_zero_point.as_ref().map(|zp| zp.view()),
+            );
+            assert_eq!(result, expected);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_matmul_integer_invalid_zero_point() {
+        let pool = new_pool();
+        let a = Tensor::<u8>::from([[1, 2], [3, 4]]);
+        let b = Tensor::<i8>::from([[5, 6], [7, 8]]);
+
+        // A zero point with more than one element is not supported.
+        let a_zero_point = Tensor::<u8>::zeros(&[2]);
+        let result = matmul_integer(&pool, a.view(), b.view(), Some(a_zero_point.view()), None);
+        assert!(result.is_err());
+    }
+
     #[test]
     #[ignore]
     fn bench_matmul() {
diff --git a/src/ops/mod.rs b/src/ops/mod.rs
index 087e5c49..88587385 100644
--- a/src/ops/mod.rs
+++ b/src/ops/mod.rs
@@ -82,7 +82,7 @@ pub use layout::{
     expand, flatten, reshape, squeeze, squeeze_in_place, Expand, Flatten, Reshape, Shape, Size,
     Squeeze, Transpose, Unsqueeze,
 };
-pub use matmul::{gemm_op, matmul, Gemm, MatMul};
+pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulInteger};
 pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
 pub use norm::{
     batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
diff --git a/src/schema.fbs b/src/schema.fbs
index c09b09aa..1b658644 100644
--- a/src/schema.fbs
+++ b/src/schema.fbs
@@ -121,6 +121,7 @@ enum OperatorType: ubyte {
   DequantizeLinear,
   QuantizeLinear,
   DynamicQuantizeLinear,
+  MatMulInteger,
 }
 
 enum RNNDirection: ubyte {
diff --git a/src/schema_generated.rs b/src/schema_generated.rs
index e7eb1667..5f74b3cd 100644
--- a/src/schema_generated.rs
+++ b/src/schema_generated.rs
@@ -18,13 +18,13 @@ pub const ENUM_MIN_OPERATOR_TYPE: u8 = 0;
     since = "2.0.0",
     note = "Use associated constants instead. This will no longer be generated in 2021."
 )]
-pub const ENUM_MAX_OPERATOR_TYPE: u8 = 107;
+pub const ENUM_MAX_OPERATOR_TYPE: u8 = 108;
 #[deprecated(
     since = "2.0.0",
     note = "Use associated constants instead. This will no longer be generated in 2021."
 )]
 #[allow(non_camel_case_types)]
-pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 108] = [
+pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 109] = [
     OperatorType::Add,
     OperatorType::ArgMin,
     OperatorType::ArgMax,
@@ -133,6 +133,7 @@ pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 108] = [
     OperatorType::DequantizeLinear,
     OperatorType::QuantizeLinear,
     OperatorType::DynamicQuantizeLinear,
+    OperatorType::MatMulInteger,
 ];
 
 #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
@@ -248,9 +249,10 @@ impl OperatorType {
     pub const DequantizeLinear: Self = Self(105);
     pub const QuantizeLinear: Self = Self(106);
     pub const DynamicQuantizeLinear: Self = Self(107);
+    pub const MatMulInteger: Self = Self(108);
 
     pub const ENUM_MIN: u8 = 0;
-    pub const ENUM_MAX: u8 = 107;
+    pub const ENUM_MAX: u8 = 108;
     pub const ENUM_VALUES: &'static [Self] = &[
         Self::Add,
         Self::ArgMin,
@@ -360,6 +362,7 @@ impl OperatorType {
         Self::DequantizeLinear,
         Self::QuantizeLinear,
         Self::DynamicQuantizeLinear,
+        Self::MatMulInteger,
     ];
     /// Returns the variant's name or "" if unknown.
     pub fn variant_name(self) -> Option<&'static str> {
@@ -472,6 +475,7 @@ impl OperatorType {
             Self::DequantizeLinear => Some("DequantizeLinear"),
             Self::QuantizeLinear => Some("QuantizeLinear"),
             Self::DynamicQuantizeLinear => Some("DynamicQuantizeLinear"),
+            Self::MatMulInteger => Some("MatMulInteger"),
             _ => None,
         }
     }