Divide prepacked matmul inputs into depth blocks
Divide prepacked matmul inputs into depth blocks whose size matches the depth
block size (`kc`) used during computation.

This allows the generic GEMM code to be agnostic of the layout of panels
within a block, as it no longer needs to slice panels along the depth
dimension. Instead, it just uses a depth block index to pick a panel with the
pre-determined block size. This in turn gives the GEMM kernel the freedom to
choose the layout within each panel.

This partly undoes #482.
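The indexing scheme can be illustrated with a small, self-contained sketch (the names, e.g. `PackedSketch`, are hypothetical and not this crate's API): the packed buffer is a sequence of depth blocks with a fixed stride, each holding panels with a per-block panel stride, so a panel is located purely from `(depth_block_idx, panel_idx)`:

```rust
/// Hypothetical illustration of the packed-A/B layout: a sequence of depth
/// blocks, each a sequence of fixed-stride panels. The last (tail) block may
/// use a smaller panel stride when the depth does not divide evenly.
struct PackedSketch {
    data: Vec<u8>,
    depth_block: usize,        // depth block size (`kc`) used when packing
    depth_block_stride: usize, // bytes between consecutive depth blocks
    panel_stride: usize,       // bytes between panels in a full depth block
    tail_panel_stride: usize,  // panel stride in the last (possibly partial) block
    depth: usize,              // K dimension of the unpacked matrix
}

impl PackedSketch {
    /// Select the bytes of `panel_idx` within depth block `depth_block_idx`.
    fn panel(&self, depth_block_idx: usize, panel_idx: usize) -> &[u8] {
        let n_blocks = self.depth.div_ceil(self.depth_block);
        // Panels in the last block are shorter if `depth % depth_block != 0`.
        let panel_stride = if depth_block_idx == n_blocks - 1 {
            self.tail_panel_stride
        } else {
            self.panel_stride
        };
        let start = depth_block_idx * self.depth_block_stride + panel_idx * panel_stride;
        &self.data[start..start + panel_stride]
    }
}

fn main() {
    // Toy numbers: K = 10 packed with a depth block of 4, i.e. two full
    // blocks plus a tail block holding the remaining 2 elements of K.
    let (depth, depth_block) = (10usize, 4usize);
    let panels_per_block = 3; // e.g. rows.div_ceil(MR) for the A input
    let bytes_per_k = 8;      // MR rows (or NR cols) * element size

    let panel_stride = depth_block * bytes_per_k;
    let tail_panel_stride = (depth % depth_block) * bytes_per_k;
    let total = (depth / depth_block) * panels_per_block * panel_stride
        + panels_per_block * tail_panel_stride;

    let packed = PackedSketch {
        data: vec![0u8; total],
        depth_block,
        depth_block_stride: panels_per_block * panel_stride,
        panel_stride,
        tail_panel_stride, // equals `panel_stride` when the depth divides evenly
        depth,
    };

    // Full blocks hold full-width panels; the tail block holds narrower ones.
    assert_eq!(packed.panel(0, 1).len(), panel_stride);
    assert_eq!(packed.panel(2, 1).len(), tail_panel_stride);
    println!("total packed bytes = {total}");
}
```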
robertknight committed Jan 4, 2025
1 parent 1168588 commit 6ec1b1b
Showing 3 changed files with 153 additions and 58 deletions.
192 changes: 143 additions & 49 deletions src/gemm.rs
@@ -26,6 +26,14 @@ use kernels::Kernel;
use packing::{PackElem, PackingBuffer};

/// Left-hand or "A" GEMM input that has been pre-packed.
///
/// The pre-packed input is divided into column blocks according to the
/// depth / K cache blocking size and then into row panels of height MR, where
/// MR is the number of rows in the kernel's tile.
///
/// Each panel has the same height, but the panels in the last column block have
/// a narrower width, and thus a smaller stride, if the number of columns is not
/// a multiple of the depth block size.
#[derive(Clone)]
pub struct PackedAMatrix<T> {
/// Sequence of packed row panels. The exact format depends upon the kernel
@@ -36,9 +44,20 @@ pub struct PackedAMatrix<T> {
/// value.
panel_height: usize,

/// Stride of each panel in `data`.
/// Number of columns in each depth block.
depth_block: usize,

/// Stride between each depth block.
depth_block_stride: usize,

/// Stride of panels in each depth block except the last.
panel_stride: usize,

/// Stride of panels in the last depth block. This will be smaller than
/// `panel_stride` if the number of columns is not a multiple of the
/// depth block size.
tail_panel_stride: usize,

/// Number of rows in the unpacked matrix.
rows: usize,

@@ -54,25 +73,27 @@ pub struct PackedAMatrix<T> {
impl<T> PackedAMatrix<T> {
/// Return the packed data for a given range along the M (`rows`) and K
/// (`depth`) dimensions.
fn block(&self, rows: Range<usize>, depth: Range<usize>) -> LhsBlock<T> {
///
/// The row range start must be a multiple of `self.panel_height`; the end is
/// rounded up to a whole panel. The depth block index selects a column range
/// of width `self.depth_block` (the final block may be narrower).
fn block(&self, rows: Range<usize>, depth_block_idx: usize) -> LhsBlock<T> {
assert_eq!(rows.start % self.panel_height, 0);

// Size of each column in the packed block in bytes. This assumes the
// specific column major layout for each row panel currently used by
// the kernels. This will need to change as new packed formats are
// introduced.
let col_size = self.panel_height * size_of::<T>();
let n_blocks = self.cols.div_ceil(self.depth_block);
let panel_stride = if depth_block_idx == n_blocks - 1 {
self.tail_panel_stride
} else {
self.panel_stride
};
let depth_block_offset = depth_block_idx * self.depth_block_stride;

let panel_range = rows.start / self.panel_height..rows.end.div_ceil(self.panel_height);
let start = panel_range.start * self.panel_stride + depth.start * col_size;
let end = (panel_range.end - 1) * self.panel_stride + depth.end * col_size;
let start = depth_block_offset + panel_range.start * panel_stride;
let end = depth_block_offset + panel_range.end * panel_stride;
let data = &self.data.as_bytes()[start..end];

LhsBlock::Packed {
data,
panel_stride: self.panel_stride,
panel_len: (depth.end - depth.start) * col_size,
}
LhsBlock::Packed { data, panel_stride }
}
}

@@ -85,6 +106,14 @@ impl<T> ExtractBuffer for PackedAMatrix<T> {
}

/// Right-hand or "B" GEMM input that has been pre-packed.
///
/// The pre-packed input is divided into row blocks according to the depth / K
/// cache blocking size and then into column panels of width NR, where NR is the
/// number of columns in the kernel's tile.
///
/// Each panel has the same width, but the panels in the last row block have a
/// shorter height, and thus a smaller stride, if the number of rows is not a
/// multiple of the depth block size.
#[derive(Clone)]
pub struct PackedBMatrix<T> {
/// Sequence of packed column panels. The exact format depends upon the
@@ -95,9 +124,20 @@ pub struct PackedBMatrix<T> {
/// value.
panel_width: usize,

/// Number of rows in each depth block.
depth_block: usize,

/// Stride between each depth block.
depth_block_stride: usize,

/// Stride of each panel in `data`.
panel_stride: usize,

/// Stride of panels in the last depth block. This will be smaller than
/// `panel_stride` if the number of rows is not a multiple of the
/// depth block size.
tail_panel_stride: usize,

/// Number of rows in the unpacked matrix.
rows: usize,

@@ -113,20 +153,25 @@ pub struct PackedBMatrix<T> {
impl<T> PackedBMatrix<T> {
/// Return the packed data for a given range along the N (`cols`) and K
/// (`depth`) dimensions.
fn block(&self, cols: Range<usize>, depth: Range<usize>) -> RhsBlock<T> {
fn block(&self, cols: Range<usize>, depth_block_idx: usize) -> RhsBlock<T> {
assert_eq!(cols.start % self.panel_width, 0);

let row_size = self.panel_width * size_of::<T>();
let n_blocks = self.rows.div_ceil(self.depth_block);
let panel_stride = if depth_block_idx == n_blocks - 1 {
self.tail_panel_stride
} else {
self.panel_stride
};
let depth_block_offset = depth_block_idx * self.depth_block_stride;

let panel_range = cols.start / self.panel_width..cols.end.div_ceil(self.panel_width);
let start = panel_range.start * self.panel_stride + depth.start * row_size;
let end = (panel_range.end - 1) * self.panel_stride + depth.end * row_size;
let start = depth_block_offset + panel_range.start * panel_stride;
let end = depth_block_offset + panel_range.end * panel_stride;
let data = &self.data.as_bytes()[start..end];

RhsBlock {
data,
panel_stride: self.panel_stride,
panel_len: (depth.end - depth.start) * row_size,
panel_stride,
_marker: PhantomData,
}
}
@@ -371,24 +416,53 @@ impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT> GemmExecutor<LhsT, RhsT, OutT
/// Variant of [`prepack_a`](GemmExecutor::prepack_a) which takes an
/// allocator.
pub fn prepack_a_in<A: Alloc>(&self, alloc: A, a: Matrix<LhsT>) -> PackedAMatrix<LhsT> {
let layout = self.kernel.packed_a_layout(a, a.rows(), a.cols());
let depth_block = depth_block_size(a.cols());

let layout = self.kernel.packed_a_layout(a, a.rows(), depth_block);
let tail_layout = if a.cols() % depth_block != 0 {
Some(
self.kernel
.packed_a_layout(a, a.rows(), a.cols() % depth_block),
)
} else {
None
};

// Require the size to be a multiple of the alignment. This avoids the
// need for any gaps between blocks, which would have to be initialized
// after packing.
assert_eq!(layout.size() % layout.align(), 0);

let n_blocks = a.cols() / depth_block;
let total_size =
(n_blocks * layout.size()) + tail_layout.as_ref().map(|l| l.size()).unwrap_or(0);

let mut data = PackingBuffer::new();
let uninit_data = data.alloc_in(alloc, &layout);
let uninit_data = data.alloc_in(alloc, total_size, layout.align());

self.kernel
.pack_a_block(uninit_data, a, 0..a.rows(), 0..a.cols());
for (col_block, block_data) in
range_chunks(0..a.cols(), depth_block).zip(uninit_data.chunks_mut(layout.size()))
{
self.kernel
.pack_a_block(block_data, a, 0..a.rows(), col_block);
}

// Safety: We used `pack_a_block` to initialize `layout.size` bytes
// Safety: We used `pack_a_block` to initialize `total_size` bytes
unsafe {
data.set_len(layout.size());
data.set_len(total_size);
}

PackedAMatrix {
data,
rows: a.rows(),
cols: a.cols(),
panel_height: self.kernel.mr(),
depth_block,
panel_stride: layout.panel_stride(),
tail_panel_stride: tail_layout
.map(|tl| tl.panel_stride())
.unwrap_or(layout.panel_stride()),
depth_block_stride: layout.size(),
kernel_name: self.kernel.name(),
_marker: PhantomData,
}
@@ -410,24 +484,52 @@ impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT> GemmExecutor<LhsT, RhsT, OutT
/// Variant of [`prepack_b`](GemmExecutor::prepack_b) which takes an
/// allocator.
pub fn prepack_b_in<A: Alloc>(&self, alloc: A, b: Matrix<RhsT>) -> PackedBMatrix<RhsT> {
let layout = self.kernel.packed_b_layout(b.rows(), b.cols());
let depth_block = depth_block_size(b.rows());

let layout = self.kernel.packed_b_layout(depth_block, b.cols());
let tail_layout = if b.rows() % depth_block != 0 {
Some(
self.kernel
.packed_b_layout(b.rows() % depth_block, b.cols()),
)
} else {
None
};

// Require the size to be a multiple of the alignment. This avoids the
// need for any gaps between blocks, which would have to be initialized
// after packing.
assert_eq!(layout.size() % layout.align(), 0);

let n_blocks = b.rows() / depth_block;
let total_size =
(n_blocks * layout.size()) + tail_layout.as_ref().map(|l| l.size()).unwrap_or(0);
let mut data = PackingBuffer::new();
let uninit_data = data.alloc_in(alloc, &layout);
let uninit_data = data.alloc_in(alloc, total_size, layout.align());

self.kernel
.pack_b_block(uninit_data, b, 0..b.rows(), 0..b.cols());
for (row_block, block_data) in
range_chunks(0..b.rows(), depth_block).zip(uninit_data.chunks_mut(layout.size()))
{
self.kernel
.pack_b_block(block_data, b, row_block, 0..b.cols());
}

// Safety: We used `pack_b_block` to initialize `total_size` bytes.
unsafe {
data.set_len(layout.size());
data.set_len(total_size);
}

PackedBMatrix {
data,
rows: b.rows(),
cols: b.cols(),
panel_width: self.kernel.nr(),
depth_block,
panel_stride: layout.panel_stride(),
tail_panel_stride: tail_layout
.map(|tl| tl.panel_stride())
.unwrap_or(layout.panel_stride()),
depth_block_stride: layout.size(),
kernel_name: self.kernel.name(),
_marker: PhantomData,
}
@@ -934,12 +1036,14 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
assert_eq!(packed.panel_height, kernel.mr());
assert_eq!(packed.rows, a.rows());
assert_eq!(packed.cols, a.cols());
assert_eq!(packed.depth_block, kc);
}
if let GemmInputB::Packed(packed) = &b {
assert_eq!(packed.kernel_name, kernel.name());
assert_eq!(packed.panel_width, kernel.nr());
assert_eq!(packed.rows, b.rows());
assert_eq!(packed.cols, b.cols());
assert_eq!(packed.depth_block, kc);
}

// Buffers for packed blocks of the matrix.
@@ -965,7 +1069,7 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(

// Loop over depth blocks. This is not parallelized because output
// tiles are shared across iterations.
for depth_range in range_chunks(0..a.cols(), kc) {
for (depth_block_idx, depth_range) in range_chunks(0..a.cols(), kc).enumerate() {
// Borrowed packing buffer for current thread. Returned after
// the GEMM block is computed.
let mut thread_local_packed_b: Option<PackingBuffer> = None;
@@ -975,7 +1079,7 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
let mut packed_b = cell.take();

let layout = kernel.packed_b_layout(depth_range.len(), col_end - col_start);
let packed_uninit = packed_b.alloc(&layout);
let packed_uninit = packed_b.alloc(layout.size(), layout.align());

match b {
GemmInputB::Unpacked(b) => kernel.pack_b_block(
@@ -1004,11 +1108,10 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
RhsBlock {
data: thread_local_packed_b.as_ref().unwrap().as_bytes(),
panel_stride: layout.panel_stride(),
panel_len: layout.panel_stride(),
_marker: PhantomData,
}
}),
GemmInputB::Packed(pm) => pm.block(col_range.clone(), depth_range.clone()),
GemmInputB::Packed(pm) => pm.block(col_range.clone(), depth_block_idx),
};

// Only use provided `beta` on the first write to this output
@@ -1043,7 +1146,7 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
};

let mut packed_a = cell.take();
let packed_uninit = packed_a.alloc(&layout);
let packed_uninit = packed_a.alloc(layout.size(), layout.align());

kernel.pack_a_block(
packed_uninit,
@@ -1060,12 +1163,9 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
LhsBlock::Packed {
data: thread_local_packed_a.as_ref().unwrap().as_bytes(),
panel_stride: layout.panel_stride(),
panel_len: layout.panel_stride(),
}
}),
GemmInputA::Packed(pm) => {
pm.block(row_range.clone(), depth_range.clone())
}
GemmInputA::Packed(pm) => pm.block(row_range.clone(), depth_block_idx),
};

gemm_block(
@@ -1102,9 +1202,6 @@ enum LhsBlock<'a, T> {

/// Stride between each row panel.
panel_stride: usize,

/// Length of each row panel.
panel_len: usize,
},

/// Unpacked A matrix. This must have a column stride of 1.
@@ -1120,9 +1217,6 @@ struct RhsBlock<'a, T> {
/// Stride between each column panel.
panel_stride: usize,

/// Size between each column panel.
panel_len: usize,

_marker: PhantomData<T>,
}

@@ -1159,7 +1253,7 @@ fn gemm_block<LhsT, RhsT, OutT: GemmOutT>(
.enumerate()
.for_each(|(block_col_tile, col_tile)| {
let b_panel_offset = block_col_tile * b.panel_stride;
let b_panel = &b.data[b_panel_offset..b_panel_offset + b.panel_len];
let b_panel = &b.data[b_panel_offset..b_panel_offset + b.panel_stride];

// Loop over row tiles.
for (block_row_tile, row_tile) in row_tiles.clone().enumerate() {
Expand All @@ -1172,10 +1266,10 @@ fn gemm_block<LhsT, RhsT, OutT: GemmOutT>(
LhsBlock::Packed {
data,
panel_stride,
panel_len,
} => {
let a_panel_offset = block_row_tile * panel_stride;
let a_panel = &data[a_panel_offset..a_panel_offset + panel_len];
let a_panel = &data[a_panel_offset..a_panel_offset + panel_stride];
kernels::Lhs::Packed(a_panel)
}
LhsBlock::Unpacked(mat) => {
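For reference, a hedged sketch of the packing pass that `prepack_a_in` / `prepack_b_in` now perform (the `range_chunks` helper and `pack_block` routine below are simplified stand-ins, not the crate's real implementations): full depth blocks share one region size, an optional tail block gets a smaller one, and `chunks_mut` over the full-block size lines the regions up with the depth chunks:

```rust
// Simplified stand-in for the crate's internal `range_chunks` helper:
// split `range` into consecutive chunks of at most `chunk` elements.
fn range_chunks(
    range: std::ops::Range<usize>,
    chunk: usize,
) -> impl Iterator<Item = std::ops::Range<usize>> {
    let end = range.end;
    range.step_by(chunk).map(move |start| start..(start + chunk).min(end))
}

// Stand-in for a kernel's pack routine: just fills the region with a marker.
fn pack_block(region: &mut [u8], depth_range: std::ops::Range<usize>) {
    region.fill(depth_range.start as u8);
}

fn main() {
    // Toy sizes: pack a 6 x 10 "A" matrix with a depth block of 4 columns.
    let (rows, depth, depth_block) = (6usize, 10usize, 4usize);
    let elem_size = 4; // e.g. size_of::<f32>()

    // Packed size of one depth block holding `kc` columns.
    let block_size = |kc: usize| rows * kc * elem_size;

    let full_size = block_size(depth_block);
    let tail_cols = depth % depth_block;
    let total = (depth / depth_block) * full_size
        + if tail_cols != 0 { block_size(tail_cols) } else { 0 };
    let mut buf = vec![0u8; total];

    // One packing call per depth block. All full blocks have size `full_size`,
    // so `chunks_mut(full_size)` pairs each region with its column chunk; the
    // final, shorter chunk (if any) is the tail block.
    for (col_block, region) in range_chunks(0..depth, depth_block).zip(buf.chunks_mut(full_size)) {
        pack_block(region, col_block);
    }
    println!("packed {} depth blocks into {total} bytes", depth.div_ceil(depth_block));
}
```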
4 changes: 2 additions & 2 deletions src/gemm/kernels.rs
@@ -42,6 +42,7 @@ pub enum Lhs<'a, T> {
/// The packed block is expected to be organized as a sequence of panels with
/// stride [`panel_stride`](PackedInfo::panel_stride), but the kernel is
/// otherwise free to choose the layout.
#[derive(Clone, Debug, PartialEq)]
pub struct PackedLayout {
size: usize,
align: usize,
@@ -130,8 +131,7 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
/// Return a name for this kernel for use in logging etc.
fn name(&self) -> &'static str;

/// Return the layout of a packing buffer required to pack a block of `a`
/// of size `rows x cols`.
/// Return the layout of a packing buffer required to pack an A / LHS input.
fn packed_a_layout(&self, a: Matrix<LhsT>, rows: usize, cols: usize) -> PackedLayout;

/// Pack a block of the LHS / "A" input for use by this kernel.