diff --git a/src/gemm.rs b/src/gemm.rs index 7622063d..3dce3b19 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -2448,8 +2448,6 @@ mod tests { run_gemm_bench::(&cases, Format::Csv); } - // Like `bench_pack_a`, but this does include allocation costs, so is - // relevant for ops which prepack inputs (eg. batched matmul). #[test] #[ignore] fn bench_prepack_a() { diff --git a/src/gemm/packing.rs b/src/gemm/packing.rs index 9ef5bb5c..7d8f3120 100644 --- a/src/gemm/packing.rs +++ b/src/gemm/packing.rs @@ -7,6 +7,42 @@ use super::kernels::PackedLayout; use crate::iter_util::range_chunks; use crate::slice_cast::{cast_pod_mut_slice, cast_pod_slice}; +/// Helper for incrementally filling a slice. +struct SliceWriter<'a, T> { + offset: usize, + slice: &'a mut [MaybeUninit], +} + +impl<'a, T> SliceWriter<'a, T> { + fn new(slice: &'a mut [MaybeUninit]) -> Self { + SliceWriter { slice, offset: 0 } + } + + /// Return true if the slice has been fully written. + fn completed(&self) -> bool { + self.offset == self.slice.len() + } + + /// Write the next element in the slice. + unsafe fn write_unchecked(&mut self, val: T) { + debug_assert!(self.offset < self.slice.len()); + self.slice.get_unchecked_mut(self.offset).write(val); + self.offset += 1; + } + + /// Write `len` copies of `val` to the slice. + unsafe fn write_n_unchecked(&mut self, len: usize, val: T) + where + T: Copy, + { + debug_assert!(self.offset + len <= self.slice.len()); + for i in 0..len { + self.slice.get_unchecked_mut(self.offset + i).write(val); + } + self.offset += len; + } +} + /// Return the required size and other metadata for packing an "A" matrix with /// [`pack_a_block`]. pub fn packed_a_layout(rows: usize, cols: usize) -> PackedLayout { @@ -28,117 +64,25 @@ pub fn pack_a_block( cols: Range, ) { assert_eq!(out.len(), rows.len().next_multiple_of(MR) * cols.len()); + let mut out = SliceWriter::new(out); - let mut out_idx = 0; for panel_rows in range_chunks(rows, MR) { for row in panel_rows.clone() { for col in cols.clone() { unsafe { - out.get_unchecked_mut(out_idx) - .write(*a.get_unchecked([row, col])); + out.write_unchecked(*a.get_unchecked([row, col])); } - out_idx += 1; } } + // Pad wth zeros for _ in panel_rows.end..panel_rows.start + MR { - for _ in cols.clone() { - unsafe { out.get_unchecked_mut(out_idx).write(T::default()) }; - out_idx += 1; - } + unsafe { out.write_n_unchecked(cols.len(), T::default()) }; } } // Make sure we initialized the entire block. - assert_eq!(out_idx, out.len()); -} - -/// Pack a block of the "A" matrix for use by a GEMM kernel, in column major -/// order. -/// -/// The packed buffer is laid out as a sequence of `ceil(rows.len() / MR)` -/// row panels. Each row panel has size `MR * cols.len()` and uses -/// column-major order. If `rows.len()` is not a multiple of `MR`, the -/// final panel is zero-padded. -/// -/// Panics if the output buffer is not exactly the correct size. -/// -/// # Safety -/// -/// When this function returns, all elements of `out` will have been initialized -/// either to a value from `a`, or zero. -#[allow(unused)] // Currently unused in favor of row-major packing. -#[inline] // Allow caller to control `target_feature`s -pub fn pack_a_block_col_major( - out: &mut [MaybeUninit], - a: Matrix, - rows: Range, - cols: Range, -) { - let a_rows = rows.len(); - let a_cols = cols.len(); - let n_panels = a_rows.next_multiple_of(MR) / MR; - let used_size = n_panels * MR * a_cols; - assert_eq!(out.len(), used_size); - - // Safety: Loops below must only access valid offsets in `a_data`. - let a_data = a.storage(); - - let row_stride = a.row_stride(); - let col_stride = a.col_stride(); - - for panel in 0..n_panels { - let panel_offset = panel * a_cols * MR; - let panel_start_row = panel * MR; - - if a_rows - panel_start_row >= MR { - // Optimized loop for panels that don't need any padding - let a_offset = (rows.start + panel_start_row) * row_stride + cols.start * col_stride; - - assert!(out.len() > panel_offset + (a_cols - 1) * MR + MR - 1); - assert!(a_data.len() > a_offset + (MR - 1) * row_stride + (a_cols - 1) * col_stride); - - // Optimize for common case of unit stride as this generates better - // code. - if col_stride == 1 { - for col in 0..a_cols { - for row in 0..MR { - // Safety: Indexes are less than lengths asserted above. - unsafe { - out.get_unchecked_mut(panel_offset + col * MR + row) - .write(*a_data.get_unchecked(a_offset + row * row_stride + col)); - } - } - } - } else { - for col in 0..a_cols { - for row in 0..MR { - // Safety: Indexes are less than lengths asserted above. - unsafe { - out.get_unchecked_mut(panel_offset + col * MR + row).write( - *a_data - .get_unchecked(a_offset + row * row_stride + col * col_stride), - ); - } - } - } - } - } else { - // Fallback for final panel if padding is required - for col in 0..a_cols { - let out_col_offset = panel_offset + col * MR; - for row in 0..MR { - let a_row = rows.start + panel_start_row + row; - out[out_col_offset + row].write(if a_row < rows.end { - let offset = a_row * row_stride + (cols.start + col) * col_stride; - unsafe { *a_data.get_unchecked(offset) } - } else { - T::default() - }); - } - } - } - } + assert!(out.completed()); } /// Return the required size and other metadata for packing a "B" matrix with @@ -181,8 +125,9 @@ pub fn pack_b_block( // Safety: Loops below must only access valid offsets in `b_data`. let b_data = b.storage(); + let mut out = SliceWriter::new(out); + for panel in 0..n_panels { - let panel_offset = panel * b_rows * NR; let panel_start_col = panel * NR; if b_cols - panel_start_col >= NR { @@ -190,7 +135,6 @@ pub fn pack_b_block( let b_offset = rows.start * b_row_stride + (cols.start + panel_start_col) * b_col_stride; - assert!(out.len() >= panel_offset + (b_rows - 1) * NR + NR); assert!( b_data.len() > b_offset + (b_rows - 1) * b_row_stride + (NR - 1) * b_col_stride ); @@ -200,25 +144,23 @@ pub fn pack_b_block( // code. if b_col_stride == 1 { for row in 0..b_rows { - let out_offset = panel_offset + row * NR; let in_offset = b_offset + row * b_row_stride; for col in 0..NR { // Safety: Indexes are less than lengths asserted above. unsafe { - out.get_unchecked_mut(out_offset + col) - .write(*b_data.get_unchecked(in_offset + col)); + out.write_unchecked(*b_data.get_unchecked(in_offset + col)); } } } } else { for row in 0..b_rows { - let out_offset = panel_offset + row * NR; let in_offset = b_offset + row * b_row_stride; for col in 0..NR { // Safety: Indexes are less than lengths asserted above. unsafe { - out.get_unchecked_mut(out_offset + col) - .write(*b_data.get_unchecked(in_offset + col * b_col_stride)); + out.write_unchecked( + *b_data.get_unchecked(in_offset + col * b_col_stride), + ); } } } @@ -226,7 +168,6 @@ pub fn pack_b_block( } else { // Fallback for final panel if padding is required for row in 0..b_rows { - let out_row_offset = panel_offset + row * NR; let b_row_offset = (rows.start + row) * b_row_stride; for col in 0..NR { @@ -234,15 +175,18 @@ pub fn pack_b_block( let b_offset = b_row_offset + (cols.start + panel_start_col + col) * b_col_stride; - out[out_row_offset + col].write(if out_col < b_cols { - unsafe { *b_data.get_unchecked(b_offset) } - } else { - T::default() - }); + unsafe { + out.write_unchecked(if out_col < b_cols { + *b_data.get_unchecked(b_offset) + } else { + T::default() + }); + } } } } } + assert!(out.completed()); } // Element type used by [`PackingBuffer`]. This must have an alignment that is