Skip to content

Commit

Permalink
Add non-optimized MatMulInteger operator
Browse files Browse the repository at this point in the history
With this it is possible to run a quantized GPT-2 model ... very slowly.
  • Loading branch information
robertknight committed Sep 12, 2024
1 parent 274d147 commit 2fb0e46
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 5 deletions.
1 change: 1 addition & 0 deletions rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class OperatorType(object):
DequantizeLinear = 105
QuantizeLinear = 106
DynamicQuantizeLinear = 107
MatMulInteger = 108


class RNNDirection(object):
Expand Down
2 changes: 2 additions & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub enum OpType<'a> {
Log,
LogSoftmax(LogSoftmax),
MatMul,
MatMulInteger,
Max,
MaxPool(MaxPool),
Mean,
Expand Down Expand Up @@ -614,6 +615,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
}
),
OpType::MatMul => op!(MatMul),
OpType::MatMulInteger => op!(MatMulInteger),
OpType::Max => op!(Max),
OpType::MaxPool(args) => op_with_attrs!(MaxPool, MaxPoolAttrs, {
let pad_args = pad_args_from_padding(args.padding);
Expand Down
2 changes: 2 additions & 0 deletions src/op_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ impl OpRegistry {
register_op!(LogSoftmax);
register_op!(LSTM);
register_op!(MatMul);
register_op!(MatMulInteger);
register_op!(Max);
register_op!(MaxPool);
register_op!(Mean);
Expand Down Expand Up @@ -610,6 +611,7 @@ impl_read_op!(LSTM, attrs_as_lstmattrs, |attrs: sg::LSTMAttrs| {
})
});
impl_read_op!(MatMul);
impl_read_op!(MatMulInteger);
impl_read_op!(Max);
impl_read_op!(
MaxPool,
Expand Down
216 changes: 215 additions & 1 deletion src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rten_tensor::{Tensor, TensorView};

use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT};
use crate::iter_util::range_chunks;
use crate::ops::binary_elementwise::broadcast_shapes;
use crate::ops::layout::expand_to;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
Expand Down Expand Up @@ -266,6 +267,117 @@ impl Operator for MatMul {
}
}

/// Multiply (batches of) `u8` matrices by `i8` matrices, accumulating in `i32`.
///
/// The last two dimensions of `a` and `b` are treated as matrices and any
/// leading dimensions are broadcast against each other. Each element has the
/// corresponding zero point subtracted (after widening to `i32`) before being
/// multiplied. A missing zero point is treated as zero.
///
/// This is a simple, non-optimized implementation.
///
/// # Errors
///
/// - `OpError::InvalidValue` if either input has fewer than two dimensions,
///   or if a zero point is supplied for which `item()` yields no value
///   (per-row / per-column zero points are not implemented here).
/// - `OpError::IncompatibleInputShapes` if the column count of `a` differs
///   from the row count of `b`, or if the batch dimensions cannot be broadcast.
pub fn matmul_integer(
    pool: &TensorPool,
    a: TensorView<u8>,
    b: TensorView<i8>,
    a_zero_point: Option<TensorView<u8>>,
    b_zero_point: Option<TensorView<i8>>,
) -> Result<Tensor<i32>, OpError> {
    if a.ndim() < 2 || b.ndim() < 2 {
        return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
    }

    let a_rows = a.size(a.ndim() - 2);
    let a_cols = a.size(a.ndim() - 1);

    let b_rows = b.size(b.ndim() - 2);
    let b_cols = b.size(b.ndim() - 1);

    if a_cols != b_rows {
        return Err(OpError::IncompatibleInputShapes(
            "Columns of first matrix does not match rows of second matrix",
        ));
    }

    let a_prefix = &a.shape()[..a.ndim() - 2];
    let b_prefix = &b.shape()[..b.ndim() - 2];

    let out_prefix = broadcast_shapes(a_prefix, b_prefix)
        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?;
    let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat();

    // With an empty inner (depth) dimension every output element is an empty
    // sum, i.e. zero. This must be handled before the blocked loop below:
    // that loop only writes output elements while visiting a depth block, so
    // with no depth blocks the buffer would be left uninitialized and the
    // final `assume_init` would be undefined behavior.
    // NOTE(review): relies on `range_chunks` yielding nothing for an empty
    // range — confirm against `iter_util`.
    if a_cols == 0 {
        return Ok(Tensor::zeros(out_shape));
    }

    let mut output = Tensor::<i32>::uninit_in(pool, out_shape);
    if output.is_empty() {
        // nb. We don't need to alloc from the pool here, since the buffer
        // is already empty.
        return Ok(Tensor::zeros(out_shape));
    }

    let a_broadcast_shape = [out_prefix.as_slice(), &[a_rows, a_cols]].concat();
    let b_broadcast_shape = [out_prefix.as_slice(), &[b_rows, b_cols]].concat();

    let a_broadcast = a.broadcast(a_broadcast_shape.as_slice());
    let b_broadcast = b.broadcast(b_broadcast_shape.as_slice());

    // Split the output buffer into one mutable chunk per output matrix.
    let out_row_stride = output.stride(output.ndim() - 2);
    let out_batches = output
        .data_mut()
        .unwrap()
        .chunks_mut(out_row_stride * a_rows);

    // Only zero points that `item()` can reduce to a single value are
    // supported. Anything else is rejected instead of being silently treated
    // as zero, which would produce wrong results without any diagnostic.
    let a_zero = match a_zero_point {
        Some(zp) => {
            let zp = zp
                .item()
                .copied()
                .ok_or(OpError::InvalidValue("Only scalar zero points are supported"))?;
            i32::from(zp)
        }
        None => 0,
    };
    let b_zero = match b_zero_point {
        Some(zp) => {
            let zp = zp
                .item()
                .copied()
                .ok_or(OpError::InvalidValue("Only scalar zero points are supported"))?;
            i32::from(zp)
        }
        None => 0,
    };

    a_broadcast
        .inner_iter::<2>()
        .zip(b_broadcast.inner_iter::<2>())
        .zip(out_batches)
        .par_bridge()
        .for_each(|((a_mat, b_mat), out_mat)| {
            let [m, depth] = a_mat.shape();
            let [_depth, n] = b_mat.shape();

            // Do some extremely rudimentary cache blocking.
            for col_block in range_chunks(0..n, 32) {
                for depth_block in range_chunks(0..depth, 32) {
                    for row_block in range_chunks(0..m, 32) {
                        for j in col_block.clone() {
                            for i in row_block.clone() {
                                // Partial dot product over this depth block.
                                let mut acc = 0i32;
                                for d in depth_block.clone() {
                                    // SAFETY: `i < m`, `d < depth` and `j < n`,
                                    // provided `range_chunks` only yields
                                    // sub-ranges of the range it was given.
                                    let a_el =
                                        unsafe { *a_mat.get_unchecked([i, d]) } as i32 - a_zero;
                                    let b_el =
                                        unsafe { *b_mat.get_unchecked([d, j]) } as i32 - b_zero;
                                    acc += a_el * b_el;
                                }
                                // SAFETY: `out_mat` is a chunk of
                                // `out_row_stride * a_rows` elements and
                                // `i < a_rows`; the offset is in bounds as
                                // long as `j < out_row_stride`.
                                // NOTE(review): assumes the freshly allocated
                                // output is contiguous so that
                                // `out_row_stride >= n` — confirm.
                                // The `assume_init` read is sound because the
                                // block starting at depth 0 is visited first
                                // for every (i, j) and writes the element.
                                unsafe {
                                    let el = out_mat.get_unchecked_mut((i * out_row_stride) + j);
                                    if depth_block.start == 0 {
                                        el.write(acc);
                                    } else {
                                        el.write(el.assume_init() + acc);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        });

    // Safety: The output is non-empty and the depth dimension is non-zero
    // (both checked above), so the loop above initialized all output elements.
    let output = unsafe { output.assume_init() };

    Ok(output)
}

/// Operator that forwards its inputs to [`matmul_integer`].
///
/// The struct has no fields; everything the operator needs comes from its
/// inputs at run time.
#[derive(Debug)]
pub struct MatMulInteger {}

impl Operator for MatMulInteger {
    /// Name under which this operator is reported.
    fn name(&self) -> &str {
        "MatMulInteger"
    }

    /// Run the operator by delegating to [`matmul_integer`].
    ///
    /// Inputs 0 and 1 are fetched with `require_as` (the two matrices), and
    /// inputs 2 and 3 with `get_as` (the respective zero points). Any error
    /// from input extraction or from the multiplication is propagated.
    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        let lhs = inputs.require_as(0)?;
        let rhs = inputs.require_as(1)?;
        let lhs_zero_point = inputs.get_as(2)?;
        let rhs_zero_point = inputs.get_as(3)?;

        let product = matmul_integer(pool, lhs, rhs, lhs_zero_point, rhs_zero_point);
        product.into_op_result()
    }
}

#[cfg(test)]
mod tests {
use std::error::Error;
Expand All @@ -277,10 +389,11 @@ mod tests {
use rten_tensor::{Tensor, TensorView, TensorViewMut};

use crate::gemm::gemm;
use crate::ops::binary_elementwise::broadcast_shapes;
use crate::ops::tests::new_pool;
use crate::tensor_pool::AutoReturn;

use super::{gemm_op, matmul, matmul_impl, MatmulStrategy, OpError};
use super::{gemm_op, matmul, matmul_impl, matmul_integer, MatmulStrategy, OpError};

fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) {
c.make_contiguous();
Expand Down Expand Up @@ -323,6 +436,53 @@ mod tests {
});
}

/// Naive reference implementation of integer matrix multiplication, used to
/// check the results of `matmul_integer`.
///
/// Batch dimensions (everything except the last two) are broadcast against
/// each other. Zero points are read via `item()` and default to zero when
/// absent or when `item()` yields nothing. Panics if the batch dimensions
/// cannot be broadcast.
fn reference_matmul_integer(
    a: TensorView<u8>,
    b: TensorView<i8>,
    a_zero_point: Option<TensorView<u8>>,
    b_zero_point: Option<TensorView<i8>>,
) -> Tensor<i32> {
    // Index of the first matrix (non-batch) dimension of each input.
    let a_mat_start = a.ndim() - 2;
    let b_mat_start = b.ndim() - 2;

    let batch_shape =
        broadcast_shapes(&a.shape()[..a_mat_start], &b.shape()[..b_mat_start]).unwrap();

    // Output shape is the broadcast batch shape followed by [rows of a, cols of b].
    let mut result_shape = batch_shape.to_vec();
    result_shape.extend([a.size(a.ndim() - 2), b.size(b.ndim() - 1)]);
    let mut out = Tensor::<i32>::zeros(&result_shape);

    let a_full_shape = [batch_shape.as_slice(), &a.shape()[a_mat_start..]].concat();
    let b_full_shape = [batch_shape.as_slice(), &b.shape()[b_mat_start..]].concat();

    let a_zero = a_zero_point
        .and_then(|zp| zp.item())
        .map_or(0, |zp| i32::from(*zp));
    let b_zero = b_zero_point
        .and_then(|zp| zp.item())
        .map_or(0, |zp| i32::from(*zp));

    let matrices = a
        .broadcast(a_full_shape.as_slice())
        .inner_iter::<2>()
        .zip(b.broadcast(b_full_shape.as_slice()).inner_iter::<2>())
        .zip(out.inner_iter_mut::<2>());

    for ((lhs, rhs), mut dest) in matrices {
        let [rows, cols] = dest.shape();
        let depth = lhs.size(1);

        for row in 0..rows {
            for col in 0..cols {
                // Dot product of a row of `lhs` with a column of `rhs`, with
                // zero points removed from each element first.
                let dot = (0..depth)
                    .map(|k| {
                        let lhs_el = i32::from(lhs[[row, k]]) - a_zero;
                        let rhs_el = i32::from(rhs[[k, col]]) - b_zero;
                        lhs_el * rhs_el
                    })
                    .sum::<i32>();
                dest[[row, col]] = dot;
            }
        }
    }

    out
}

#[test]
fn test_gemm_op() -> Result<(), Box<dyn Error>> {
let pool = new_pool();
Expand Down Expand Up @@ -573,6 +733,60 @@ mod tests {
}
}

/// Checks `matmul_integer` against `reference_matmul_integer`, both without
/// zero points and with scalar zero points.
#[test]
fn test_matmul_integer() -> Result<(), Box<dyn Error>> {
    struct Case {
        a: Tensor<u8>,
        b: Tensor<i8>,
        a_zero_point: Option<Tensor<u8>>,
        b_zero_point: Option<Tensor<i8>>,
    }

    let pool = new_pool();

    let cases = [
        // No zero point
        Case {
            a: Tensor::from([[1, 2], [3, 4]]),
            b: Tensor::from([[5, 6], [7, 8]]),
            a_zero_point: None,
            b_zero_point: None,
        },
        // Scalar zero points
        Case {
            a: Tensor::from([[1, 2], [3, 4]]),
            b: Tensor::from([[5, 6], [7, 8]]),
            a_zero_point: Some(Tensor::from(127)),
            b_zero_point: Some(Tensor::from(-50)),
        },
    ];

    for case in cases {
        // Run the implementation under test first so that any error it
        // returns is propagated before the reference is evaluated.
        let actual = matmul_integer(
            &pool,
            case.a.view(),
            case.b.view(),
            case.a_zero_point.as_ref().map(|zp| zp.view()),
            case.b_zero_point.as_ref().map(|zp| zp.view()),
        )?;

        let wanted = reference_matmul_integer(
            case.a.view(),
            case.b.view(),
            case.a_zero_point.as_ref().map(|zp| zp.view()),
            case.b_zero_point.as_ref().map(|zp| zp.view()),
        );

        assert_eq!(actual, wanted);
    }

    Ok(())
}

#[test]
#[ignore]
fn bench_matmul() {
Expand Down
2 changes: 1 addition & 1 deletion src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub use layout::{
expand, flatten, reshape, squeeze, squeeze_in_place, Expand, Flatten, Reshape, Shape, Size,
Squeeze, Transpose, Unsqueeze,
};
pub use matmul::{gemm_op, matmul, Gemm, MatMul};
pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulInteger};
pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
pub use norm::{
batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
Expand Down
1 change: 1 addition & 0 deletions src/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ enum OperatorType: ubyte {
DequantizeLinear,
QuantizeLinear,
DynamicQuantizeLinear,
MatMulInteger,
}

enum RNNDirection: ubyte {
Expand Down
10 changes: 7 additions & 3 deletions src/schema_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ pub const ENUM_MIN_OPERATOR_TYPE: u8 = 0;
since = "2.0.0",
note = "Use associated constants instead. This will no longer be generated in 2021."
)]
pub const ENUM_MAX_OPERATOR_TYPE: u8 = 107;
pub const ENUM_MAX_OPERATOR_TYPE: u8 = 108;
#[deprecated(
since = "2.0.0",
note = "Use associated constants instead. This will no longer be generated in 2021."
)]
#[allow(non_camel_case_types)]
pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 108] = [
pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 109] = [
OperatorType::Add,
OperatorType::ArgMin,
OperatorType::ArgMax,
Expand Down Expand Up @@ -133,6 +133,7 @@ pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 108] = [
OperatorType::DequantizeLinear,
OperatorType::QuantizeLinear,
OperatorType::DynamicQuantizeLinear,
OperatorType::MatMulInteger,
];

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
Expand Down Expand Up @@ -248,9 +249,10 @@ impl OperatorType {
pub const DequantizeLinear: Self = Self(105);
pub const QuantizeLinear: Self = Self(106);
pub const DynamicQuantizeLinear: Self = Self(107);
pub const MatMulInteger: Self = Self(108);

pub const ENUM_MIN: u8 = 0;
pub const ENUM_MAX: u8 = 107;
pub const ENUM_MAX: u8 = 108;
pub const ENUM_VALUES: &'static [Self] = &[
Self::Add,
Self::ArgMin,
Expand Down Expand Up @@ -360,6 +362,7 @@ impl OperatorType {
Self::DequantizeLinear,
Self::QuantizeLinear,
Self::DynamicQuantizeLinear,
Self::MatMulInteger,
];
/// Returns the variant's name, or `None` if unknown.
pub fn variant_name(self) -> Option<&'static str> {
Expand Down Expand Up @@ -472,6 +475,7 @@ impl OperatorType {
Self::DequantizeLinear => Some("DequantizeLinear"),
Self::QuantizeLinear => Some("QuantizeLinear"),
Self::DynamicQuantizeLinear => Some("DynamicQuantizeLinear"),
Self::MatMulInteger => Some("MatMulInteger"),
_ => None,
}
}
Expand Down

0 comments on commit 2fb0e46

Please sign in to comment.