Merge pull request #534 from robertknight/slice-write-helper
Add helper to ensure packing buffer is fully written, remove unused col-major packing code
robertknight authored Jan 12, 2025
2 parents 27f8c96 + d6a50d4 commit 48c3aef
Showing 2 changed files with 55 additions and 113 deletions.
src/gemm.rs: 2 changes (0 additions, 2 deletions)
@@ -2448,8 +2448,6 @@ mod tests {
run_gemm_bench::<f32, f32, f32>(&cases, Format::Csv);
}

// Like `bench_pack_a`, but this does include allocation costs, so is
// relevant for ops which prepack inputs (eg. batched matmul).
#[test]
#[ignore]
fn bench_prepack_a() {
src/gemm/packing.rs: 166 changes (55 additions, 111 deletions)
@@ -7,6 +7,42 @@ use super::kernels::PackedLayout;
use crate::iter_util::range_chunks;
use crate::slice_cast::{cast_pod_mut_slice, cast_pod_slice};

/// Helper for incrementally filling a slice.
struct SliceWriter<'a, T> {
offset: usize,
slice: &'a mut [MaybeUninit<T>],
}

impl<'a, T> SliceWriter<'a, T> {
fn new(slice: &'a mut [MaybeUninit<T>]) -> Self {
SliceWriter { slice, offset: 0 }
}

/// Return true if the slice has been fully written.
fn completed(&self) -> bool {
self.offset == self.slice.len()
}

/// Write the next element in the slice.
unsafe fn write_unchecked(&mut self, val: T) {
debug_assert!(self.offset < self.slice.len());
self.slice.get_unchecked_mut(self.offset).write(val);
self.offset += 1;
}

/// Write `len` copies of `val` to the slice.
unsafe fn write_n_unchecked(&mut self, len: usize, val: T)
where
T: Copy,
{
debug_assert!(self.offset + len <= self.slice.len());
for i in 0..len {
self.slice.get_unchecked_mut(self.offset + i).write(val);
}
self.offset += len;
}
}

/// Return the required size and other metadata for packing an "A" matrix with
/// [`pack_a_block`].
pub fn packed_a_layout<T, const MR: usize>(rows: usize, cols: usize) -> PackedLayout {
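
The pattern this helper enables is to wrap the output buffer once, emit every element through write_unchecked / write_n_unchecked, and finish with a single completed() check instead of per-element index bookkeeping. A minimal standalone sketch of that pattern, assuming the SliceWriter definition from the hunk above is in scope (the buffer size and values are purely illustrative):

use std::mem::MaybeUninit;

fn slice_writer_sketch() {
    // Illustrative 2-row, 3-column panel: one real row plus one padding row.
    let mut buf: Vec<MaybeUninit<f32>> = vec![MaybeUninit::uninit(); 6];
    let mut writer = SliceWriter::new(&mut buf);

    // Write one row of real data...
    for val in [1.0_f32, 2.0, 3.0] {
        // Safety: exactly buf.len() elements are written in total.
        unsafe { writer.write_unchecked(val) };
    }
    // ...then zero-fill the padding row in one call.
    unsafe { writer.write_n_unchecked(3, 0.0) };

    // The packing routines end with this check to guarantee the whole
    // output block was initialized.
    assert!(writer.completed());
}
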
@@ -28,117 +64,25 @@ pub fn pack_a_block<T: Copy + Default, const MR: usize>(
cols: Range<usize>,
) {
assert_eq!(out.len(), rows.len().next_multiple_of(MR) * cols.len());
let mut out = SliceWriter::new(out);

let mut out_idx = 0;
for panel_rows in range_chunks(rows, MR) {
for row in panel_rows.clone() {
for col in cols.clone() {
unsafe {
out.get_unchecked_mut(out_idx)
.write(*a.get_unchecked([row, col]));
out.write_unchecked(*a.get_unchecked([row, col]));
}
out_idx += 1;
}
}

// Pad with zeros
for _ in panel_rows.end..panel_rows.start + MR {
for _ in cols.clone() {
unsafe { out.get_unchecked_mut(out_idx).write(T::default()) };
out_idx += 1;
}
unsafe { out.write_n_unchecked(cols.len(), T::default()) };
}
}

// Make sure we initialized the entire block.
assert_eq!(out_idx, out.len());
}

/// Pack a block of the "A" matrix for use by a GEMM kernel, in column major
/// order.
///
/// The packed buffer is laid out as a sequence of `ceil(rows.len() / MR)`
/// row panels. Each row panel has size `MR * cols.len()` and uses
/// column-major order. If `rows.len()` is not a multiple of `MR`, the
/// final panel is zero-padded.
///
/// Panics if the output buffer is not exactly the correct size.
///
/// # Safety
///
/// When this function returns, all elements of `out` will have been initialized
/// either to a value from `a`, or zero.
#[allow(unused)] // Currently unused in favor of row-major packing.
#[inline] // Allow caller to control `target_feature`s
pub fn pack_a_block_col_major<T: Copy + Default, const MR: usize>(
out: &mut [MaybeUninit<T>],
a: Matrix<T>,
rows: Range<usize>,
cols: Range<usize>,
) {
let a_rows = rows.len();
let a_cols = cols.len();
let n_panels = a_rows.next_multiple_of(MR) / MR;
let used_size = n_panels * MR * a_cols;
assert_eq!(out.len(), used_size);

// Safety: Loops below must only access valid offsets in `a_data`.
let a_data = a.storage();

let row_stride = a.row_stride();
let col_stride = a.col_stride();

for panel in 0..n_panels {
let panel_offset = panel * a_cols * MR;
let panel_start_row = panel * MR;

if a_rows - panel_start_row >= MR {
// Optimized loop for panels that don't need any padding
let a_offset = (rows.start + panel_start_row) * row_stride + cols.start * col_stride;

assert!(out.len() > panel_offset + (a_cols - 1) * MR + MR - 1);
assert!(a_data.len() > a_offset + (MR - 1) * row_stride + (a_cols - 1) * col_stride);

// Optimize for common case of unit stride as this generates better
// code.
if col_stride == 1 {
for col in 0..a_cols {
for row in 0..MR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
out.get_unchecked_mut(panel_offset + col * MR + row)
.write(*a_data.get_unchecked(a_offset + row * row_stride + col));
}
}
}
} else {
for col in 0..a_cols {
for row in 0..MR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
out.get_unchecked_mut(panel_offset + col * MR + row).write(
*a_data
.get_unchecked(a_offset + row * row_stride + col * col_stride),
);
}
}
}
}
} else {
// Fallback for final panel if padding is required
for col in 0..a_cols {
let out_col_offset = panel_offset + col * MR;
for row in 0..MR {
let a_row = rows.start + panel_start_row + row;
out[out_col_offset + row].write(if a_row < rows.end {
let offset = a_row * row_stride + (cols.start + col) * col_stride;
unsafe { *a_data.get_unchecked(offset) }
} else {
T::default()
});
}
}
}
}
assert!(out.completed());
}
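
For reference, the output size asserted at the top of pack_a_block corresponds to packing the block as ceil(rows / MR) row panels of MR * cols elements each, with the final panel zero-padded. A quick arithmetic sketch (MR = 4 is illustrative, not the kernel's actual tile size):

fn packed_a_size_sketch() {
    const MR: usize = 4;       // illustrative micro-panel height
    let rows: usize = 5;       // rows in the block being packed
    let cols: usize = 3;       // columns in the block being packed

    // One row panel per MR rows; the last panel is padded with zeros,
    // so the row count is rounded up to a multiple of MR.
    let packed_len = rows.next_multiple_of(MR) * cols;
    assert_eq!(packed_len, 8 * 3); // two panels of MR * cols elements
}
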

/// Return the required size and other metadata for packing a "B" matrix with
@@ -181,16 +125,16 @@ pub fn pack_b_block<T: Copy + Default, const NR: usize>(
// Safety: Loops below must only access valid offsets in `b_data`.
let b_data = b.storage();

let mut out = SliceWriter::new(out);

for panel in 0..n_panels {
let panel_offset = panel * b_rows * NR;
let panel_start_col = panel * NR;

if b_cols - panel_start_col >= NR {
// Optimized loop for panels that don't need any padding
let b_offset =
rows.start * b_row_stride + (cols.start + panel_start_col) * b_col_stride;

assert!(out.len() >= panel_offset + (b_rows - 1) * NR + NR);
assert!(
b_data.len() > b_offset + (b_rows - 1) * b_row_stride + (NR - 1) * b_col_stride
);
@@ -200,49 +144,49 @@
// code.
if b_col_stride == 1 {
for row in 0..b_rows {
let out_offset = panel_offset + row * NR;
let in_offset = b_offset + row * b_row_stride;
for col in 0..NR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
out.get_unchecked_mut(out_offset + col)
.write(*b_data.get_unchecked(in_offset + col));
out.write_unchecked(*b_data.get_unchecked(in_offset + col));
}
}
}
} else {
for row in 0..b_rows {
let out_offset = panel_offset + row * NR;
let in_offset = b_offset + row * b_row_stride;
for col in 0..NR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
out.get_unchecked_mut(out_offset + col)
.write(*b_data.get_unchecked(in_offset + col * b_col_stride));
out.write_unchecked(
*b_data.get_unchecked(in_offset + col * b_col_stride),
);
}
}
}
}
} else {
// Fallback for final panel if padding is required
for row in 0..b_rows {
let out_row_offset = panel_offset + row * NR;
let b_row_offset = (rows.start + row) * b_row_stride;

for col in 0..NR {
let out_col = panel_start_col + col;
let b_offset =
b_row_offset + (cols.start + panel_start_col + col) * b_col_stride;

out[out_row_offset + col].write(if out_col < b_cols {
unsafe { *b_data.get_unchecked(b_offset) }
} else {
T::default()
});
unsafe {
out.write_unchecked(if out_col < b_cols {
*b_data.get_unchecked(b_offset)
} else {
T::default()
});
}
}
}
}
}
assert!(out.completed());
}
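
The "B" side mirrors this with NR-wide column panels: each panel holds b_rows * NR elements, and trailing columns of the final panel are written as zeros by the fallback branch above. A parallel sketch (NR = 4 is illustrative):

fn packed_b_size_sketch() {
    const NR: usize = 4;       // illustrative micro-panel width
    let b_rows: usize = 6;     // rows in the "B" block
    let b_cols: usize = 5;     // columns in the "B" block

    let n_panels = b_cols.next_multiple_of(NR) / NR;
    let packed_len = n_panels * b_rows * NR;
    assert_eq!(n_panels, 2);
    assert_eq!(packed_len, 48); // the last 3 of 8 packed columns are zero padding
}
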

// Element type used by [`PackingBuffer`]. This must have an alignment that is