Improve the documentation for the Kernel trait
In particular, note the caveat that the output may be uninitialized if `beta` is
zero. This is an aspect that ought to be modeled in the types but currently is
not.
robertknight committed Sep 10, 2024
1 parent 22585ff commit 11f5722
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions src/gemm/kernels.rs
@@ -317,9 +317,14 @@ unsafe fn simd_gemm<S: SimdFloat, const MR: usize, const NR_REGS: usize>(
     }
 }
 
-/// Kernel that computes a small tile of a matrix multiplication output.
+/// Kernel that computes a small tile of a general matrix multiplication (GEMM) or
+/// general matrix-vector multiplication (GEMV).
 ///
-/// The kernel corresponds to Loop 6 (the "microkernel") in Page 4 of [^1]. The
+/// This trait is an interface for the architecture-specific inner loop for matrix
+/// multiplication and matrix-vector multiplication, as well as the methods that
+/// pack the input matrices into a format that is efficient for the kernel to use.
+///
+/// The GEMM kernel corresponds to Loop 6 (the "microkernel") in Page 4 of [^1]. The
 /// tile size depends upon the kernel and is specified by the `MR` and `NR`
 /// associated constants. See Section 3.2 [^1] for theory behind choosing the
 /// `MR` and `NR` values.
@@ -365,14 +370,21 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
         cols: Range<usize>,
     );
 
-    /// Compute a tile of the output matrix. The output is stored in row-major
-    /// order with `MR` rows and `NR` columns, a row stride of `tile_row_stride`
-    /// and column stride of 1.
+    /// Compute a tile of the output matrix.
+    ///
+    /// The output is stored in row-major order with `MR` rows and `NR` columns,
+    /// a row stride of `tile_row_stride` and column stride of 1.
+    ///
+    /// The `a` and `b` inputs are the input matrices packed by the `pack_a_block`
+    /// and `pack_b_block` methods. The `depth` input specifies the number of
+    /// columns of A and rows of B that are in the packed inputs.
     ///
     /// # Safety
     ///
     /// The caller must ensure that `tile_ptr` points to a buffer of the correct
-    /// size.
+    /// size. If `beta` is zero then the output may be uninitialized and must not
+    /// be read by the implementation. If `beta` is non-zero then the output must
+    /// be initialized and the implementation will read from it.
     unsafe fn kernel(
         &self,
         tile_ptr: *mut OutT,
@@ -398,7 +410,7 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
     ///
     /// # Safety
     ///
-    /// The caller must ensure that the kernel is supported on the current
-    /// system.
+    /// If `beta` is zero then the output may be uninitialized and must not be
+    /// read by the implementation.
     fn gemv_kernel(&self, out: &mut [OutT], a: &[LhsT], b: Matrix<RhsT>, alpha: f32, beta: OutT);
 }
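
The `beta` contract documented for `kernel` can be illustrated with a minimal scalar sketch. This is not the crate's implementation: `reference_kernel`, `tile_with_beta_zero` and `tile_with_beta_one` are hypothetical names, and both the parameter list after `tile_ptr` and the packed layout (one column of A with `MR` values, then one row of B with `NR` values, per depth step) are assumptions made for illustration. The point is that when `beta` is zero the tile is only written, so the caller may pass uninitialized memory, whereas a non-zero `beta` requires reading the existing contents.

```rust
use std::mem::MaybeUninit;

/// Hypothetical scalar reference kernel computing
/// `tile = alpha * (A x B) + beta * tile` for an `MR x NR` tile.
///
/// Assumed packed layout (illustration only): for each depth step `k`,
/// `a[k * MR..][..MR]` is column `k` of A and `b[k * NR..][..NR]` is row `k` of B.
///
/// # Safety
///
/// `tile_ptr` must be valid for writes of an `MR x NR` tile with the given
/// row stride. If `beta` is non-zero the tile must also be initialized.
unsafe fn reference_kernel<const MR: usize, const NR: usize>(
    tile_ptr: *mut f32,
    tile_row_stride: usize,
    a: &[f32],
    b: &[f32],
    depth: usize,
    alpha: f32,
    beta: f32,
) {
    assert!(a.len() >= depth * MR && b.len() >= depth * NR);

    // Accumulate the product in a local buffer; the output is not touched yet.
    let mut acc = [[0.0f32; NR]; MR];
    for k in 0..depth {
        for i in 0..MR {
            for j in 0..NR {
                acc[i][j] += a[k * MR + i] * b[k * NR + j];
            }
        }
    }

    for i in 0..MR {
        for j in 0..NR {
            unsafe {
                let out = tile_ptr.add(i * tile_row_stride + j);
                if beta == 0.0 {
                    // Output may be uninitialized: write only, never read.
                    out.write(alpha * acc[i][j]);
                } else {
                    // Output must be initialized: read it, then accumulate.
                    out.write(alpha * acc[i][j] + beta * out.read());
                }
            }
        }
    }
}

// A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]] in the assumed packed layout.
const PACKED_A: [f32; 4] = [1.0, 3.0, 2.0, 4.0];
const PACKED_B: [f32; 4] = [5.0, 6.0, 7.0, 8.0];

/// Calls the kernel with `beta == 0` on an uninitialized 2x2 tile.
fn tile_with_beta_zero() -> [f32; 4] {
    let mut tile = [MaybeUninit::<f32>::uninit(); 4];
    unsafe {
        let ptr = tile.as_mut_ptr().cast::<f32>();
        reference_kernel::<2, 2>(ptr, 2, &PACKED_A, &PACKED_B, 2, 1.0, 0.0);
    }
    // Every element was written by the kernel, so this is sound.
    tile.map(|x| unsafe { x.assume_init() })
}

/// Calls the kernel with `beta == 1` on an initialized 2x2 tile.
fn tile_with_beta_one(mut tile: [f32; 4]) -> [f32; 4] {
    unsafe {
        reference_kernel::<2, 2>(tile.as_mut_ptr(), 2, &PACKED_A, &PACKED_B, 2, 1.0, 1.0);
    }
    tile
}
```

Passing a raw pointer plus a runtime `beta` means the compiler cannot check which case applies, which is the gap the commit message refers to when it says this ought to be modeled in the types.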
