Skip to content

Commit

Permalink
Add non-optimized MatMulInteger operator
Browse files Browse the repository at this point in the history
With this it is possible to run a quantized GPT-2 model ... very slowly.
  • Loading branch information
robertknight committed Sep 12, 2024
1 parent 274d147 commit 979314d
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 5 deletions.
1 change: 1 addition & 0 deletions rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class OperatorType(object):
DequantizeLinear = 105
QuantizeLinear = 106
DynamicQuantizeLinear = 107
MatMulInteger = 108


class RNNDirection(object):
Expand Down
2 changes: 2 additions & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub enum OpType<'a> {
Log,
LogSoftmax(LogSoftmax),
MatMul,
MatMulInteger,
Max,
MaxPool(MaxPool),
Mean,
Expand Down Expand Up @@ -614,6 +615,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
}
),
OpType::MatMul => op!(MatMul),
OpType::MatMulInteger => op!(MatMulInteger),
OpType::Max => op!(Max),
OpType::MaxPool(args) => op_with_attrs!(MaxPool, MaxPoolAttrs, {
let pad_args = pad_args_from_padding(args.padding);
Expand Down
2 changes: 2 additions & 0 deletions src/op_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ impl OpRegistry {
register_op!(LogSoftmax);
register_op!(LSTM);
register_op!(MatMul);
register_op!(MatMulInteger);
register_op!(Max);
register_op!(MaxPool);
register_op!(Mean);
Expand Down Expand Up @@ -610,6 +611,7 @@ impl_read_op!(LSTM, attrs_as_lstmattrs, |attrs: sg::LSTMAttrs| {
})
});
impl_read_op!(MatMul);
impl_read_op!(MatMulInteger);
impl_read_op!(Max);
impl_read_op!(
MaxPool,
Expand Down
295 changes: 294 additions & 1 deletion src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rten_tensor::{Tensor, TensorView};

use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT};
use crate::iter_util::range_chunks;
use crate::ops::binary_elementwise::broadcast_shapes;
use crate::ops::layout::expand_to;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
Expand Down Expand Up @@ -266,6 +267,126 @@ impl Operator for MatMul {
}
}

pub fn matmul_integer(
pool: &TensorPool,
a: TensorView<u8>,
b: TensorView<i8>,
a_zero_point: Option<TensorView<u8>>,
b_zero_point: Option<TensorView<i8>>,
) -> Result<Tensor<i32>, OpError> {
if a.ndim() < 2 || b.ndim() < 2 {
return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
}

let a_rows = a.size(a.ndim() - 2);
let a_cols = a.size(a.ndim() - 1);

let b_rows = b.size(b.ndim() - 2);
let b_cols = b.size(b.ndim() - 1);

if a_cols != b_rows {
return Err(OpError::IncompatibleInputShapes(
"Columns of first matrix does not match rows of second matrix",
));
}

let a_prefix = &a.shape()[..a.ndim() - 2];
let b_prefix = &b.shape()[..b.ndim() - 2];

let out_prefix = broadcast_shapes(a_prefix, b_prefix)
.ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?;
let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat();

let mut output = Tensor::<i32>::uninit_in(pool, out_shape);
if output.is_empty() {
// nb. We don't need to alloc from the pool here, since the buffer
// is already empty.
return Ok(Tensor::zeros(out_shape));
}

let a_broadcast_shape = [out_prefix.as_slice(), &[a_rows, a_cols]].concat();
let b_broadcast_shape = [out_prefix.as_slice(), &[b_rows, b_cols]].concat();

let a_broadcast = a.broadcast(a_broadcast_shape.as_slice());
let b_broadcast = b.broadcast(b_broadcast_shape.as_slice());

fn is_scalar<T>(tensor: &Option<TensorView<T>>) -> bool {
tensor.as_ref().map(|zp| zp.ndim() == 0).unwrap_or(true)
}

if !is_scalar(&a_zero_point) || !is_scalar(&b_zero_point) {
return Err(OpError::UnsupportedValue(
"Only scalar zero points are supported",
));
}

let a_zero = a_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;
let b_zero = b_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;

a_broadcast
.inner_iter::<2>()
.zip(b_broadcast.inner_iter::<2>())
.zip(output.inner_iter_mut::<2>())
.par_bridge()
.for_each(|((a_mat, b_mat), mut out_mat)| {
let [m, k] = a_mat.shape();
let [bk, n] = b_mat.shape();
assert_eq!(k, bk);
assert_eq!(out_mat.shape(), [m, n]);

// Do some extremely rudimentary cache blocking.
for col_block in range_chunks(0..n, 32) {
for depth_block in range_chunks(0..k, 32) {
for row_block in range_chunks(0..m, 32) {
for j in col_block.clone() {
for i in row_block.clone() {
let mut out = 0i32;
for k in depth_block.clone() {
// Safety: `[i, k]` is in-bounds for `a_mat`.
let a = unsafe { *a_mat.get_unchecked([i, k]) } as i32 - a_zero;
// Safety: `[k, j]` is in-bounds for `b_mat`.
let b = unsafe { *b_mat.get_unchecked([k, j]) } as i32 - b_zero;
out += a * b;
}
unsafe {
// Safety: `[i, j]` is in-bounds for `b_mat`.
let el = out_mat.get_unchecked_mut([i, j]);
if depth_block.start == 0 {
el.write(out);
} else {
el.write(el.assume_init() + out);
}
}
}
}
}
}
}
});

// Safety: Loop above initialized all output elements.
let output = unsafe { output.assume_init() };

Ok(output)
}

#[derive(Debug)]
pub struct MatMulInteger {}

impl Operator for MatMulInteger {
fn name(&self) -> &str {
"MatMulInteger"
}

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let a = inputs.require_as(0)?;
let b = inputs.require_as(1)?;
let a_zero_point = inputs.get_as(2)?;
let b_zero_point = inputs.get_as(3)?;
matmul_integer(pool, a, b, a_zero_point, b_zero_point).into_op_result()
}
}

#[cfg(test)]
mod tests {
use std::error::Error;
Expand All @@ -277,10 +398,11 @@ mod tests {
use rten_tensor::{Tensor, TensorView, TensorViewMut};

use crate::gemm::gemm;
use crate::ops::binary_elementwise::broadcast_shapes;
use crate::ops::tests::new_pool;
use crate::tensor_pool::AutoReturn;

use super::{gemm_op, matmul, matmul_impl, MatmulStrategy, OpError};
use super::{gemm_op, matmul, matmul_impl, matmul_integer, MatmulStrategy, OpError};

fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) {
c.make_contiguous();
Expand Down Expand Up @@ -323,6 +445,53 @@ mod tests {
});
}

fn reference_matmul_integer(
a: TensorView<u8>,
b: TensorView<i8>,
a_zero_point: Option<TensorView<u8>>,
b_zero_point: Option<TensorView<i8>>,
) -> Tensor<i32> {
let a_batch_dims = a.ndim() - 2;
let b_batch_dims = b.ndim() - 2;

let a_prefix = &a.shape()[..a.ndim() - 2];
let b_prefix = &b.shape()[..b.ndim() - 2];
let out_prefix = broadcast_shapes(a_prefix, b_prefix).unwrap();
let mut out_shape = out_prefix.to_vec();
out_shape.push(a.size(a.ndim() - 2));
out_shape.push(b.size(b.ndim() - 1));
let mut out = Tensor::<i32>::zeros(&out_shape);

let a_bcast = [out_prefix.as_slice(), &a.shape()[a_batch_dims..]].concat();
let b_bcast = [out_prefix.as_slice(), &b.shape()[b_batch_dims..]].concat();

let a_zero_point = a_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;
let b_zero_point = b_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;

a.broadcast(a_bcast.as_slice())
.inner_iter::<2>()
.zip(b.broadcast(b_bcast.as_slice()).inner_iter::<2>())
.zip(out.inner_iter_mut::<2>())
.for_each(|((a, b), mut c)| {
let [n_rows, n_cols] = c.shape();
let depth = a.size(1);

for i in 0..n_rows {
for j in 0..n_cols {
let mut y = 0;
for k in 0..depth {
let a_el = (a[[i, k]] as i32) - a_zero_point;
let b_el = (b[[k, j]] as i32) - b_zero_point;
y += a_el * b_el;
}
c[[i, j]] = y;
}
}
});

out
}

#[test]
fn test_gemm_op() -> Result<(), Box<dyn Error>> {
let pool = new_pool();
Expand Down Expand Up @@ -573,6 +742,130 @@ mod tests {
}
}

#[test]
fn test_matmul_integer() -> Result<(), Box<dyn Error>> {
struct Case {
a: Tensor<u8>,
b: Tensor<i8>,
a_zero_point: Option<Tensor<u8>>,
b_zero_point: Option<Tensor<i8>>,
expected_err: Option<OpError>,
}

let cases = [
// No zero point
Case {
a: Tensor::from([[1, 2], [3, 4]]),
b: Tensor::from([[5, 6], [7, 8]]),
a_zero_point: None,
b_zero_point: None,
expected_err: None,
},
// Scalar zero points
Case {
a: Tensor::from([[1, 2], [3, 4]]),
b: Tensor::from([[5, 6], [7, 8]]),
a_zero_point: Some(Tensor::from(127)),
b_zero_point: Some(Tensor::from(-50)),
expected_err: None,
},
// Non-scalar zero points
Case {
a: Tensor::from([[2, 2], [2, 2]]),
b: Tensor::from([[2, 2], [2, 2]]),
a_zero_point: Some(Tensor::from([[2, 2], [2, 2]])),
b_zero_point: None,
expected_err: Some(OpError::UnsupportedValue(
"Only scalar zero points are supported",
)),
},
Case {
a: Tensor::from([[2, 2], [2, 2]]),
b: Tensor::from([[2, 2], [2, 2]]),
a_zero_point: None,
b_zero_point: Some(Tensor::from([[2, 2], [2, 2]])),
expected_err: Some(OpError::UnsupportedValue(
"Only scalar zero points are supported",
)),
},
// Empty output
Case {
a: Tensor::zeros(&[0, 2]),
b: Tensor::zeros(&[2, 3]),
a_zero_point: None,
b_zero_point: None,
expected_err: None,
},
// Mismatched shapes
Case {
a: Tensor::zeros(&[1, 2]),
b: Tensor::zeros(&[3, 1]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::IncompatibleInputShapes(
"Columns of first matrix does not match rows of second matrix",
)),
},
Case {
a: Tensor::zeros(&[1]),
b: Tensor::zeros(&[3, 1]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::InvalidValue("Inputs must have >= 2 dimensions")),
},
Case {
a: Tensor::zeros(&[1, 2]),
b: Tensor::zeros(&[1]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::InvalidValue("Inputs must have >= 2 dimensions")),
},
Case {
a: Tensor::zeros(&[2, 2, 2]),
b: Tensor::zeros(&[3, 2, 2]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::IncompatibleInputShapes("Cannot broadcast shapes")),
},
];

let pool = new_pool();

for Case {
a,
b,
a_zero_point,
b_zero_point,
expected_err,
} in cases
{
let result = matmul_integer(
&pool,
a.view(),
b.view(),
a_zero_point.as_ref().map(|zp| zp.view()),
b_zero_point.as_ref().map(|zp| zp.view()),
);

match (result, expected_err) {
(Ok(result), None) => {
let expected = reference_matmul_integer(
a.view(),
b.view(),
a_zero_point.as_ref().map(|zp| zp.view()),
b_zero_point.as_ref().map(|zp| zp.view()),
);
assert_eq!(result, expected);
}
(result, expected_err) => {
assert_eq!(result.err(), expected_err);
}
}
}

Ok(())
}

#[test]
#[ignore]
fn bench_matmul() {
Expand Down
2 changes: 1 addition & 1 deletion src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub use layout::{
expand, flatten, reshape, squeeze, squeeze_in_place, Expand, Flatten, Reshape, Shape, Size,
Squeeze, Transpose, Unsqueeze,
};
pub use matmul::{gemm_op, matmul, Gemm, MatMul};
pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulInteger};
pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
pub use norm::{
batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
Expand Down
1 change: 1 addition & 0 deletions src/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ enum OperatorType: ubyte {
DequantizeLinear,
QuantizeLinear,
DynamicQuantizeLinear,
MatMulInteger,
}

enum RNNDirection: ubyte {
Expand Down
Loading

0 comments on commit 979314d

Please sign in to comment.