Skip to content

Commit

Permalink
Merge pull request #355 from robertknight/gemm-executor-default
Browse files Browse the repository at this point in the history
Make `Gemm` operator implementation generic
  • Loading branch information
robertknight authored Sep 12, 2024
2 parents 0bf9014 + 0710f07 commit d73b89e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 34 deletions.
24 changes: 16 additions & 8 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,18 +217,20 @@ impl<'a, T> GemmInputB<'a, T> {
/// This computes `output = alpha * (a @ b) + beta * output` where `@` is
/// matrix multiplication.
#[allow(unused)]
pub fn gemm(
out_data: &mut [f32],
pub fn gemm<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
out_data: &mut [OutT],
out_row_stride: usize,
a: Matrix,
b: Matrix,
a: Matrix<LhsT>,
b: Matrix<RhsT>,
alpha: f32,
beta: f32,
) {
beta: OutT,
) where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
// This heap-allocates a new kernel on each call. That's OK because this
// is very cheap relative to the large matmuls we expect to be doing, but
// would be good to avoid for small inputs.
GemmExecutor::new().gemm(
GemmExecutor::default().gemm(
out_data,
out_row_stride,
GemmInputA::Unpacked(a),
Expand All @@ -240,7 +242,7 @@ pub fn gemm(

/// Executes matrix multiplication operations.
///
/// For simple use cases, the standalone [gemm] function can be used.
/// For simple use cases, the standalone [`gemm`] function can be used.
/// GemmExecutor provides a more advanced API that enables features such as
/// performing matrix multiplications with pre-packed inputs.
///
Expand Down Expand Up @@ -569,6 +571,12 @@ impl GemmExecutor<f32, f32, f32> {
}
}

impl Default for GemmExecutor<f32, f32, f32> {
fn default() -> Self {
Self::new()
}
}

/// Return the block size for the K / depth dimension of a GEMM operation.
fn depth_block_size(a_cols: usize) -> usize {
256.min(a_cols)
Expand Down
65 changes: 39 additions & 26 deletions src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,36 @@ use crate::ops::layout::expand_to;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
use crate::tensor_pool::{AutoReturn, TensorPool};

#[derive(Debug)]
pub struct Gemm {
pub alpha: f32,
pub beta: f32,
pub transpose_a: bool,
pub transpose_b: bool,
}

/// Compute the General Matrix Multiplication (GEMM) `c = alpha * (ab) + beta * c`.
///
/// If `transpose_a` or `transpose_b` are set, the `a` and `b` inputs
/// respectively are transposed before multiplying them.
///
/// nb. This is named `gemm_op` to avoid confusion with `gemm::gemm`.
pub fn gemm_op(
pub fn gemm_op<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
pool: &TensorPool,
a: TensorView,
b: TensorView,
c: Option<TensorView>,
a: TensorView<LhsT>,
b: TensorView<RhsT>,
c: Option<TensorView<OutT>>,
alpha: f32,
beta: f32,
beta: OutT,
transpose_a: bool,
transpose_b: bool,
) -> Result<Tensor, OpError> {
) -> Result<Tensor<OutT>, OpError>
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
check_dims!(a, 2);
check_dims!(b, 2);

let a = if transpose_a { a.transposed() } else { a };
let b = if transpose_b { b.transposed() } else { b };

let out_shape = &[a.size(0), b.size(1)][..];
let gemm = GemmExecutor::new();
let gemm = GemmExecutor::<LhsT, RhsT, OutT>::default();

let output = match c {
Some(c) if beta != 0. => {
Some(c) if beta != OutT::zero() => {
if !c.can_broadcast_to(out_shape) {
return Err(OpError::IncompatibleInputShapes(
"Cannot broadcast c to output shape",
Expand Down Expand Up @@ -80,6 +75,14 @@ pub fn gemm_op(
Ok(output)
}

/// Operator computing the General Matrix Multiplication
/// `c = alpha * (a @ b) + beta * c`. See [`gemm_op`] for details.
#[derive(Debug)]
pub struct Gemm {
    // Scaling factor applied to the `a @ b` product.
    pub alpha: f32,
    // Scaling factor applied to the existing `c` input.
    pub beta: f32,
    // If true, `a` is transposed before multiplying.
    pub transpose_a: bool,
    // If true, `b` is transposed before multiplying.
    pub transpose_b: bool,
}

impl Operator for Gemm {
fn name(&self) -> &str {
"Gemm"
Expand All @@ -89,7 +92,7 @@ impl Operator for Gemm {
let a = inputs.require_as(0)?;
let b = inputs.require_as(1)?;
let c = inputs.get_as(2)?;
gemm_op(
gemm_op::<f32, f32, f32>(
pool,
a,
b,
Expand All @@ -116,17 +119,26 @@ enum MatmulStrategy {
Batch,
}

pub fn matmul(pool: &TensorPool, a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
matmul_impl(pool, &GemmExecutor::new(), a, b, MatmulStrategy::Auto)
/// Multiply matrices `a` and `b`, returning `a @ b`.
///
/// Inputs must have at least two dimensions, otherwise an error is
/// returned. Delegates to `matmul_impl` with automatic strategy
/// selection ([`MatmulStrategy::Auto`]). The `GemmExecutor<LhsT, RhsT,
/// OutT>: Default` bound restricts this to input/output type
/// combinations for which a default executor exists.
pub fn matmul<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
    pool: &TensorPool,
    a: TensorView<LhsT>,
    b: TensorView<RhsT>,
) -> Result<Tensor<OutT>, OpError>
where
    GemmExecutor<LhsT, RhsT, OutT>: Default,
{
    matmul_impl(pool, a, b, MatmulStrategy::Auto)
}

fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
pool: &TensorPool,
gemm: &GemmExecutor<LhsT, RhsT, OutT>,
a: TensorView<LhsT>,
b: TensorView<RhsT>,
strategy: MatmulStrategy,
) -> Result<Tensor<OutT>, OpError> {
) -> Result<Tensor<OutT>, OpError>
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
if a.ndim() < 2 || b.ndim() < 2 {
return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
}
Expand Down Expand Up @@ -164,7 +176,7 @@ fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
// nb. We assume `a` is likely already contiguous, so this will be cheap.
let a_contig = a.to_contiguous_in(pool).auto_return(pool);
let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
let mut output = matmul_impl(pool, gemm, a_matrix, b.clone(), strategy)?;
let mut output = matmul(pool, a_matrix, b.clone())?;
output.reshape(out_shape);
return Ok(output);
}
Expand All @@ -188,6 +200,8 @@ fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
.unwrap()
.chunks_mut(out_row_stride * a_rows);

let gemm = GemmExecutor::default();

// Prepack re-used inputs to amortize packing cost.
//
// We don't prepack when the "A" matrix is a vector because that uses a
Expand Down Expand Up @@ -248,7 +262,7 @@ impl Operator for MatMul {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let a = inputs.require_as(0)?;
let b = inputs.require_as(1)?;
matmul(pool, a, b).into_op_result()
matmul::<f32, f32, f32>(pool, a, b).into_op_result()
}
}

Expand All @@ -262,7 +276,7 @@ mod tests {
use rten_tensor::test_util::expect_equal;
use rten_tensor::{Tensor, TensorView, TensorViewMut};

use crate::gemm::{gemm, GemmExecutor};
use crate::gemm::gemm;
use crate::ops::tests::new_pool;
use crate::tensor_pool::AutoReturn;

Expand Down Expand Up @@ -602,8 +616,7 @@ mod tests {
);
let pool = new_pool();
run_bench(trials, Some(&desc), || {
let gemm = GemmExecutor::new();
matmul_impl(&pool, &gemm, a.view(), b.view(), strategy)
matmul_impl(&pool, a.view(), b.view(), strategy)
.unwrap()
.auto_return(&pool);
});
Expand Down

0 comments on commit d73b89e

Please sign in to comment.