diff --git a/src/gemm.rs b/src/gemm.rs
index 64ea0539..80e6a28f 100644
--- a/src/gemm.rs
+++ b/src/gemm.rs
@@ -26,6 +26,14 @@ use kernels::Kernel;
 use packing::{PackElem, PackingBuffer};
 
 /// Left-hand or "A" GEMM input that has been pre-packed.
+///
+/// The pre-packed input is divided into column blocks according to the
+/// depth / K cache blocking size and then into row panels of height MR, where
+/// MR is the number of rows in the kernel's tile.
+///
+/// Each panel has the same height, but the panels in the last column block
+/// are narrower, and thus have a smaller stride, if the number of columns is
+/// not a multiple of the depth block size.
 #[derive(Clone)]
 pub struct PackedAMatrix<T> {
     /// Sequence of packed row panels. The exact format depends upon the kernel
@@ -36,9 +44,20 @@ pub struct PackedAMatrix<T> {
     /// value.
     panel_height: usize,
 
-    /// Stride of each panel in `data`.
+    /// Number of columns in each depth block.
+    depth_block: usize,
+
+    /// Stride between each depth block.
+    depth_block_stride: usize,
+
+    /// Stride of panels in depth blocks except for the last.
     panel_stride: usize,
 
+    /// Stride of panels in the last depth block. This will be smaller than
+    /// `panel_stride` if the number of columns is not a multiple of the
+    /// depth block size.
+    tail_panel_stride: usize,
+
     /// Number of rows in the unpacked matrix.
     rows: usize,
@@ -54,25 +73,27 @@ pub struct PackedAMatrix<T> {
 impl<T> PackedAMatrix<T> {
     /// Return the packed data for a given range along the M (`rows`) and K
     /// (`depth`) dimensions.
-    fn block(&self, rows: Range<usize>, depth: Range<usize>) -> LhsBlock<T> {
+    ///
+    /// The row range start and end must be multiples of `self.panel_height`.
+    /// The depth block index specifies the column range as a multiple of
+    /// `self.depth_block`.
+    fn block(&self, rows: Range<usize>, depth_block_idx: usize) -> LhsBlock<T> {
         assert_eq!(rows.start % self.panel_height, 0);
 
-        // Size of each column in the packed block in bytes. This assumes the
-        // specific column major layout for each row panel currently used by
-        // the kernels. This will need to change as new packed formats are
-        // introduced.
-        let col_size = self.panel_height * size_of::<T>();
+        let n_blocks = self.cols.div_ceil(self.depth_block);
+        let panel_stride = if depth_block_idx == n_blocks - 1 {
+            self.tail_panel_stride
+        } else {
+            self.panel_stride
+        };
+        let depth_block_offset = depth_block_idx * self.depth_block_stride;
 
         let panel_range = rows.start / self.panel_height..rows.end.div_ceil(self.panel_height);
-        let start = panel_range.start * self.panel_stride + depth.start * col_size;
-        let end = (panel_range.end - 1) * self.panel_stride + depth.end * col_size;
+        let start = depth_block_offset + panel_range.start * panel_stride;
+        let end = depth_block_offset + panel_range.end * panel_stride;
         let data = &self.data.as_bytes()[start..end];
 
-        LhsBlock::Packed {
-            data,
-            panel_stride: self.panel_stride,
-            panel_len: (depth.end - depth.start) * col_size,
-        }
+        LhsBlock::Packed { data, panel_stride }
     }
 }
 
@@ -85,6 +106,14 @@ impl<T> ExtractBuffer for PackedAMatrix<T> {
 }
 
 /// Right-hand or "B" GEMM input that has been pre-packed.
+///
+/// The pre-packed input is divided into row blocks according to the depth / K
+/// cache blocking size and then into column panels of width NR, where NR is
+/// the number of columns in the kernel's tile.
+///
+/// Each panel has the same width, but the panels in the last row block are
+/// shorter, and thus have a smaller stride, if the number of rows is not a
+/// multiple of the depth block size.
 #[derive(Clone)]
 pub struct PackedBMatrix<T> {
     /// Sequence of packed column panels. The exact format depends upon the
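Note: to make the A-side layout concrete, here is a standalone sketch (not part
of this diff; names and sizes are illustrative) that mirrors the offset
arithmetic `PackedAMatrix::block` performs above:

    use std::ops::Range;

    // Standalone sketch of the packed-A offset arithmetic.
    fn packed_a_range(
        panel_height: usize,       // MR
        depth_block: usize,        // columns per depth block (kc)
        depth_block_stride: usize, // bytes per full depth block
        panel_stride: usize,       // bytes per panel in full blocks
        tail_panel_stride: usize,  // bytes per panel in the last, short block
        cols: usize,               // K
        rows: Range<usize>,
        depth_block_idx: usize,
    ) -> Range<usize> {
        assert_eq!(rows.start % panel_height, 0);
        let n_blocks = cols.div_ceil(depth_block);
        // Panels in the final (possibly short) block have a smaller stride.
        let stride = if depth_block_idx == n_blocks - 1 {
            tail_panel_stride
        } else {
            panel_stride
        };
        let block_offset = depth_block_idx * depth_block_stride;
        let panels = rows.start / panel_height..rows.end.div_ceil(panel_height);
        block_offset + panels.start * stride..block_offset + panels.end * stride
    }

    fn main() {
        // M=16, K=256, kc=160, MR=8, f32 elements: two row panels per block,
        // depth blocks of 160 and 96 columns.
        let (full, tail) = (8 * 160 * 4, 8 * 96 * 4);
        let range = packed_a_range(8, 160, 2 * full, full, tail, 256, 8..16, 1);
        assert_eq!(range, 2 * full + tail..2 * full + 2 * tail);
    }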
@@ -95,9 +124,20 @@ pub struct PackedBMatrix<T> {
     /// value.
     panel_width: usize,
 
+    /// Number of rows in each depth block.
+    depth_block: usize,
+
+    /// Stride between each depth block.
+    depth_block_stride: usize,
+
-    /// Stride of each panel in `data`.
+    /// Stride of panels in depth blocks except for the last.
     panel_stride: usize,
 
+    /// Stride of panels in the last depth block. This will be smaller than
+    /// `panel_stride` if the number of rows is not a multiple of the
+    /// depth block size.
+    tail_panel_stride: usize,
+
     /// Number of rows in the unpacked matrix.
     rows: usize,
@@ -113,20 +153,25 @@ impl<T> PackedBMatrix<T> {
     /// Return the packed data for a given range along the N (`cols`) and K
     /// (`depth`) dimensions.
-    fn block(&self, cols: Range<usize>, depth: Range<usize>) -> RhsBlock<T> {
+    fn block(&self, cols: Range<usize>, depth_block_idx: usize) -> RhsBlock<T> {
         assert_eq!(cols.start % self.panel_width, 0);
 
-        let row_size = self.panel_width * size_of::<T>();
+        let n_blocks = self.rows.div_ceil(self.depth_block);
+        let panel_stride = if depth_block_idx == n_blocks - 1 {
+            self.tail_panel_stride
+        } else {
+            self.panel_stride
+        };
+        let depth_block_offset = depth_block_idx * self.depth_block_stride;
+
         let panel_range = cols.start / self.panel_width..cols.end.div_ceil(self.panel_width);
-        let start = panel_range.start * self.panel_stride + depth.start * row_size;
-        let end = (panel_range.end - 1) * self.panel_stride + depth.end * row_size;
+        let start = depth_block_offset + panel_range.start * panel_stride;
+        let end = depth_block_offset + panel_range.end * panel_stride;
         let data = &self.data.as_bytes()[start..end];
 
         RhsBlock {
             data,
-            panel_stride: self.panel_stride,
-            panel_len: (depth.end - depth.start) * row_size,
+            panel_stride,
             _marker: PhantomData,
         }
     }
 }
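Note: the buffer sizing performed by the prepack functions below can be
summarized by this sketch. `Layout` is a stand-in for the real `PackedLayout`
from src/gemm/kernels.rs; it also shows why the new size/alignment assert
matters, since blocks are laid out back to back with no padding:

    // Sketch of the packed-buffer sizing used by prepack_a_in / prepack_b_in.
    struct Layout {
        size: usize,
        align: usize,
    }

    // `full` describes one depth_block-sized block and `tail` the final short
    // block, present when K % depth_block != 0.
    fn total_packed_size(k: usize, depth_block: usize, full: &Layout, tail: Option<&Layout>) -> usize {
        // A block whose size is a multiple of its alignment keeps every
        // following block aligned without gap bytes between blocks.
        assert_eq!(full.size % full.align, 0);
        (k / depth_block) * full.size + tail.map(|t| t.size).unwrap_or(0)
    }

    fn main() {
        let full = Layout { size: 10240, align: 32 };
        let tail = Layout { size: 6144, align: 32 };
        // K=256, kc=160: one full block plus one tail block.
        assert_eq!(total_packed_size(256, 160, &full, Some(&tail)), 16384);
    }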
@@ -371,16 +416,40 @@ impl GemmExecutor
     pub fn prepack_a_in<A: Alloc>(&self, alloc: A, a: Matrix<LhsT>) -> PackedAMatrix<LhsT> {
-        let layout = self.kernel.packed_a_layout(a, a.rows(), a.cols());
+        let depth_block = depth_block_size(a.cols());
+
+        let layout = self.kernel.packed_a_layout(a, a.rows(), depth_block);
+        let tail_layout = if a.cols() % depth_block != 0 {
+            Some(
+                self.kernel
+                    .packed_a_layout(a, a.rows(), a.cols() % depth_block),
+            )
+        } else {
+            None
+        };
+
+        // Require the size to be a multiple of the alignment. This avoids the
+        // need for any gaps between blocks, which would have to be initialized
+        // after packing.
+        assert_eq!(layout.size() % layout.align(), 0);
+
+        let n_blocks = a.cols() / depth_block;
+        let total_size =
+            (n_blocks * layout.size()) + tail_layout.as_ref().map(|l| l.size()).unwrap_or(0);
+
         let mut data = PackingBuffer::new();
-        let uninit_data = data.alloc_in(alloc, &layout);
+        let uninit_data = data.alloc_in(alloc, total_size, layout.align());
 
-        self.kernel
-            .pack_a_block(uninit_data, a, 0..a.rows(), 0..a.cols());
+        for (col_block, block_data) in
+            range_chunks(0..a.cols(), depth_block).zip(uninit_data.chunks_mut(layout.size()))
+        {
+            self.kernel
+                .pack_a_block(block_data, a, 0..a.rows(), col_block);
+        }
 
-        // Safety: We used `pack_a_block` to initialize `layout.size` bytes
+        // Safety: We used `pack_a_block` to initialize `total_size` bytes
         unsafe {
-            data.set_len(layout.size());
+            data.set_len(total_size);
         }
 
         PackedAMatrix {
@@ -388,7 +457,12 @@ impl GemmExecutor
             data,
             rows: a.rows(),
             cols: a.cols(),
             panel_height: self.kernel.mr(),
+            depth_block,
+            depth_block_stride: layout.size(),
             panel_stride: layout.panel_stride(),
+            tail_panel_stride: tail_layout
+                .map(|l| l.panel_stride())
+                .unwrap_or(layout.panel_stride()),
             kernel_name: self.kernel.name(),
@@ … @@ impl GemmExecutor
     pub fn prepack_b_in<A: Alloc>(&self, alloc: A, b: Matrix<RhsT>) -> PackedBMatrix<RhsT> {
-        let layout = self.kernel.packed_b_layout(b.rows(), b.cols());
+        let depth_block = depth_block_size(b.rows());
+
+        let layout = self.kernel.packed_b_layout(depth_block, b.cols());
+        let tail_layout = if b.rows() % depth_block != 0 {
+            Some(
+                self.kernel
+                    .packed_b_layout(b.rows() % depth_block, b.cols()),
+            )
+        } else {
+            None
+        };
+
+        // Require the size to be a multiple of the alignment. This avoids the
+        // need for any gaps between blocks, which would have to be initialized
+        // after packing.
+        assert_eq!(layout.size() % layout.align(), 0);
+
+        let n_blocks = b.rows() / depth_block;
+        let total_size =
+            (n_blocks * layout.size()) + tail_layout.as_ref().map(|l| l.size()).unwrap_or(0);
 
         let mut data = PackingBuffer::new();
-        let uninit_data = data.alloc_in(alloc, &layout);
+        let uninit_data = data.alloc_in(alloc, total_size, layout.align());
 
-        self.kernel
-            .pack_b_block(uninit_data, b, 0..b.rows(), 0..b.cols());
+        for (row_block, block_data) in
+            range_chunks(0..b.rows(), depth_block).zip(uninit_data.chunks_mut(layout.size()))
+        {
+            self.kernel
+                .pack_b_block(block_data, b, row_block, 0..b.cols());
+        }
 
-        // Safety: We used `pack_b_block` to initialize `layout.size` bytes.
+        // Safety: We used `pack_b_block` to initialize `total_size` bytes.
         unsafe {
-            data.set_len(layout.size());
+            data.set_len(total_size);
         }
 
         PackedBMatrix {
@@ -427,7 +524,12 @@ impl GemmExecutor
             data,
             rows: b.rows(),
             cols: b.cols(),
             panel_width: self.kernel.nr(),
+            depth_block,
+            depth_block_stride: layout.size(),
             panel_stride: layout.panel_stride(),
+            tail_panel_stride: tail_layout
+                .map(|l| l.panel_stride())
+                .unwrap_or(layout.panel_stride()),
             kernel_name: self.kernel.name(),
@@ … @@ fn gemm_impl(
     if let GemmInputA::Packed(packed) = &a {
         assert_eq!(packed.kernel_name, kernel.name());
         assert_eq!(packed.panel_height, kernel.mr());
         assert_eq!(packed.rows, a.rows());
         assert_eq!(packed.cols, a.cols());
+        assert_eq!(packed.depth_block, kc);
     }
     if let GemmInputB::Packed(packed) = &b {
         assert_eq!(packed.kernel_name, kernel.name());
         assert_eq!(packed.panel_width, kernel.nr());
         assert_eq!(packed.rows, b.rows());
         assert_eq!(packed.cols, b.cols());
+        assert_eq!(packed.depth_block, kc);
     }
 
     // Buffers for packed blocks of the matrix.
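Note: `depth_block_size` is called by the prepack functions but is not defined
in this diff. A hypothetical sketch, consistent with the new
`assert_eq!(packed.depth_block, kc)` checks in `gemm_impl`, is a clamp of the
K dimension to the cache blocking size; the constant here is purely
illustrative:

    // Hypothetical sketch only: the real depth_block_size lives outside this
    // diff and must agree with the `kc` used by gemm_impl.
    fn depth_block_size(k: usize) -> usize {
        const KC: usize = 256; // illustrative cache blocking size
        k.min(KC)
    }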
@@ -965,7 +1069,7 @@ fn gemm_impl(
 
     // Loop over depth blocks. This is not parallelized because output
     // tiles are shared across iterations.
-    for depth_range in range_chunks(0..a.cols(), kc) {
+    for (depth_block_idx, depth_range) in range_chunks(0..a.cols(), kc).enumerate() {
         // Borrowed packing buffer for current thread. Returned after
         // the GEMM block is computed.
         let mut thread_local_packed_b: Option<PackingBuffer> = None;
@@ -975,7 +1079,7 @@ fn gemm_impl(
                 let mut packed_b = cell.take();
                 let layout = kernel.packed_b_layout(depth_range.len(), col_end - col_start);
-                let packed_uninit = packed_b.alloc(&layout);
+                let packed_uninit = packed_b.alloc(layout.size(), layout.align());
 
                 match b {
                     GemmInputB::Unpacked(b) => kernel.pack_b_block(
@@ -1004,11 +1108,10 @@ fn gemm_impl(
                 RhsBlock {
                     data: thread_local_packed_b.as_ref().unwrap().as_bytes(),
                     panel_stride: layout.panel_stride(),
-                    panel_len: layout.panel_stride(),
                     _marker: PhantomData,
                 }
             }),
-            GemmInputB::Packed(pm) => pm.block(col_range.clone(), depth_range.clone()),
+            GemmInputB::Packed(pm) => pm.block(col_range.clone(), depth_block_idx),
         };
 
         // Only use provided `beta` on the first write to this output
@@ -1043,7 +1146,7 @@ fn gemm_impl(
                     };
                     let mut packed_a = cell.take();
-                    let packed_uninit = packed_a.alloc(&layout);
+                    let packed_uninit = packed_a.alloc(layout.size(), layout.align());
 
                     kernel.pack_a_block(
                         packed_uninit,
@@ -1060,12 +1163,9 @@ fn gemm_impl(
                     LhsBlock::Packed {
                         data: thread_local_packed_a.as_ref().unwrap().as_bytes(),
                         panel_stride: layout.panel_stride(),
-                        panel_len: layout.panel_stride(),
                     }
                 }),
-                GemmInputA::Packed(pm) => {
-                    pm.block(row_range.clone(), depth_range.clone())
-                }
+                GemmInputA::Packed(pm) => pm.block(row_range.clone(), depth_block_idx),
             };
 
             gemm_block(
@@ -1102,9 +1202,6 @@ enum LhsBlock<'a, T> {
         /// Stride between each row panel.
         panel_stride: usize,
-
-        /// Length of each row panel.
-        panel_len: usize,
     },
 
     /// Unpacked A matrix. This must have a column stride of 1.
@@ -1120,9 +1217,6 @@ struct RhsBlock<'a, T> {
     /// Stride between each column panel.
     panel_stride: usize,
 
-    /// Size between each column panel.
-    panel_len: usize,
-
     _marker: PhantomData<T>,
 }
@@ -1159,7 +1253,7 @@ fn gemm_block(
         .enumerate()
         .for_each(|(block_col_tile, col_tile)| {
             let b_panel_offset = block_col_tile * b.panel_stride;
-            let b_panel = &b.data[b_panel_offset..b_panel_offset + b.panel_len];
+            let b_panel = &b.data[b_panel_offset..b_panel_offset + b.panel_stride];
 
             // Loop over row tiles.
             for (block_row_tile, row_tile) in row_tiles.clone().enumerate() {
@@ -1172,10 +1266,9 @@ fn gemm_block(
                     LhsBlock::Packed {
                         data,
                         panel_stride,
-                        panel_len,
                     } => {
                         let a_panel_offset = block_row_tile * panel_stride;
-                        let a_panel = &data[a_panel_offset..a_panel_offset + panel_len];
+                        let a_panel = &data[a_panel_offset..a_panel_offset + panel_stride];
                         kernels::Lhs::Packed(a_panel)
                     }
                     LhsBlock::Unpacked(mat) => {
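Note: the `panel_len` fields could be dropped because, within a depth block,
panels are now stored back to back, so the stride between panels equals each
panel's length and one field suffices. In miniature (illustrative only):

    // Within one depth block, panel stride == panel length.
    fn nth_panel(block: &[u8], panel_stride: usize, n: usize) -> &[u8] {
        &block[n * panel_stride..(n + 1) * panel_stride]
    }

    fn main() {
        let block: Vec<u8> = (0..12).collect(); // three 4-byte panels
        assert_eq!(nth_panel(&block, 4, 1), &[4u8, 5, 6, 7]);
    }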
diff --git a/src/gemm/kernels.rs b/src/gemm/kernels.rs
index db9bc7c1..dcc2caac 100644
--- a/src/gemm/kernels.rs
+++ b/src/gemm/kernels.rs
@@ -42,6 +42,7 @@ pub enum Lhs<'a, T> {
 /// The packed block is expected to be organized as a sequence of panels with
 /// stride [`panel_stride`](PackedInfo::panel_stride), but the kernel is
 /// otherwise free to choose the layout.
+#[derive(Clone, Debug, PartialEq)]
 pub struct PackedLayout {
     size: usize,
     align: usize,
@@ -130,8 +131,7 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
     /// Return a name for this kernel for use in logging etc.
     fn name(&self) -> &'static str;
 
-    /// Return the layout of a packing buffer required to pack a block of `a`
-    /// of size `rows x cols`.
+    /// Return the layout of a packing buffer required to pack an A / LHS input.
     fn packed_a_layout(&self, a: Matrix<LhsT>, rows: usize, cols: usize) -> PackedLayout;
 
     /// Pack a block of the LHS / "A" input for use by this kernel.
diff --git a/src/gemm/packing.rs b/src/gemm/packing.rs
index 49ac146f..b287ec1d 100644
--- a/src/gemm/packing.rs
+++ b/src/gemm/packing.rs
@@ -237,10 +237,10 @@ impl PackingBuffer {
     ///
-    /// Returns an uninitialized slice of `layout.size()` bytes which the
+    /// Returns an uninitialized slice of `size` bytes which the
     /// caller must fill.
-    pub fn alloc(&mut self, layout: &PackedLayout) -> &mut [MaybeUninit<u8>] {
-        assert!(layout.align() <= align_of::<PackElem>());
+    pub fn alloc(&mut self, size: usize, align: usize) -> &mut [MaybeUninit<u8>] {
+        assert!(align <= align_of::<PackElem>());
 
-        let buf_len = layout.size().div_ceil(size_of::<PackElem>());
+        let buf_len = size.div_ceil(size_of::<PackElem>());
         self.buf.clear();
         self.buf.reserve(buf_len);
         self.used_len = 0;
@@ -257,11 +257,12 @@ pub fn alloc_in<A: Alloc>(
         &mut self,
         alloc: A,
-        layout: &PackedLayout,
+        size: usize,
+        align: usize,
     ) -> &mut [MaybeUninit<u8>] {
-        assert!(layout.align() <= align_of::<PackElem>());
+        assert!(align <= align_of::<PackElem>());
 
-        let buf_len = layout.size().div_ceil(size_of::<PackElem>());
+        let buf_len = size.div_ceil(size_of::<PackElem>());
         self.buf = alloc.alloc::<PackElem>(buf_len);
         self.used_len = 0;
@@ -327,7 +328,7 @@ mod tests {
         assert_eq!(buf.as_bytes().len(), 0);
 
         let layout = PackedLayout::new(size, align, panel_stride);
-        let uninit_data = buf.alloc(&layout);
+        let uninit_data = buf.alloc(layout.size(), layout.align());
 
         assert_eq!(uninit_data.len(), layout.size());
         uninit_data.fill(MaybeUninit::new(0));
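Note: moving `alloc` from `&PackedLayout` to plain `(size, align)` arguments is
what lets one buffer hold several depth blocks whose combined size is not any
single layout's size. A minimal mock of the new contract, assuming nothing
about the real `PackingBuffer` beyond what the test above exercises (the real
buffer is backed by `Vec<PackElem>` to guarantee alignment; a `Vec<u8>` stands
in here, so only the bookkeeping is illustrated):

    // Minimal mock of the new size/align contract of `PackingBuffer::alloc`.
    struct Buffer {
        buf: Vec<u8>,
    }

    impl Buffer {
        fn alloc(&mut self, size: usize, align: usize) -> &mut [u8] {
            // The real implementation bounds `align` by the backing element's
            // alignment rather than allocating with explicit alignment.
            assert!(align.is_power_of_two());
            self.buf.clear();
            self.buf.resize(size, 0);
            &mut self.buf
        }
    }

    fn main() {
        let mut buf = Buffer { buf: Vec::new() };
        // Two full depth blocks plus one tail block share a single
        // allocation, sized as in prepack_a_in / prepack_b_in.
        let bytes = buf.alloc(2 * 10240 + 6144, 8);
        assert_eq!(bytes.len(), 26624);
    }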