Merge pull request #535 from robertknight/gemm-int8-x64
Implement u8 x i8 -> i32 GEMM kernel for x86_64 using AVX2 intrinsics
robertknight authored Jan 13, 2025
2 parents 48c3aef + f7e4a51 commit 8fe2402
Showing 4 changed files with 575 additions and 10 deletions.
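
The new kernel itself lives in one of the other changed files (the x86_64 kernels module), not in the src/gemm.rs hunks shown below. For orientation, here is a minimal editorial sketch, not taken from this commit, of the AVX2 idiom that u8 x i8 -> i32 kernels of this kind typically build on: `_mm256_maddubs_epi16` (vpmaddubsw) followed by `_mm256_madd_epi16` (vpmaddwd). The function name and raw-pointer signature are illustrative only.

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Editorial sketch: accumulate 32 u8 x i8 products into eight i32 lanes.
///
/// # Safety
/// Requires AVX2 and 32 readable bytes behind `a` and `b`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_u8i8_avx2(acc: __m256i, a: *const u8, b: *const i8) -> __m256i {
    let va = _mm256_loadu_si256(a as *const __m256i); // 32 x u8
    let vb = _mm256_loadu_si256(b as *const __m256i); // 32 x i8
    // vpmaddubsw: u8 x i8 -> i16, adjacent pairs added with signed
    // saturation (the saturation the new test RNG below works around).
    let pairs_i16 = _mm256_maddubs_epi16(va, vb);
    // vpmaddwd: multiply by 1 and add adjacent pairs, widening to 8 x i32.
    let quads_i32 = _mm256_madd_epi16(pairs_i16, _mm256_set1_epi16(1));
    _mm256_add_epi32(acc, quads_i32)
}
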
57 changes: 47 additions & 10 deletions src/gemm.rs
@@ -551,6 +551,10 @@ impl Default for GemmExecutor<f32, f32, f32> {

impl Default for GemmExecutor<u8, i8, i32> {
fn default() -> Self {
#[cfg(target_arch = "x86_64")]
if let Some(gemm) = Self::from_kernel::<kernels::x86_64::Avx2Int8Kernel>(KernelType::Fma) {
return gemm;
}
Self::from_kernel::<GenericKernel>(KernelType::Generic).unwrap()
}
}
@@ -1539,10 +1543,37 @@
}
}

// Random number generator which produces values with a reduced range.
//
// This works around an issue under AVX2 where the `vpmaddubsw` instruction
// can encounter saturation when adding two signed 16-bit values into a
// 16-bit result. Each of the two 16-bit inputs is the result of a `u8 x i8`
// multiplication. By limiting the range of either the u8 or i8 input, we
// can avoid saturation. `(i16::MAX / 2).isqrt() == 127`, so if we ensure
// both int8 values are <= 127, saturation won't occur.
//
// This issue does not affect the VNNI instruction used on newer x64 systems.
struct ReducedRangeRng {
rng: XorShiftRng,
}

impl ReducedRangeRng {
fn new() -> Self {
Self {
rng: XorShiftRng::new(1234),
}
}

fn next_u8(&mut self) -> u8 {
(self.rng.next_u64() % 128) as u8
}
}
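
// Editorial sketch, not part of this commit: a scalar model of one 16-bit
// lane of `vpmaddubsw`, making the saturation bound above concrete. The
// helper and test names are hypothetical.
fn maddubs_lane(a: [u8; 2], b: [i8; 2]) -> i16 {
    // Each u8 x i8 product fits exactly in i16 (at most 255 * 127 = 32385);
    // only the pairwise add saturates, matching the instruction's semantics.
    let p0 = (a[0] as i16) * (b[0] as i16);
    let p1 = (a[1] as i16) * (b[1] as i16);
    p0.saturating_add(p1)
}

#[test]
fn test_maddubs_lane_saturation() {
    // Full-range operands: 32385 + 32385 = 64770, saturates to i16::MAX.
    assert_eq!(maddubs_lane([255, 255], [127, 127]), i16::MAX);
    // Reduced-range operands (both <= 127): 16129 + 16129 = 32258, exact.
    assert_eq!(maddubs_lane([127, 127], [127, 127]), 32258);
}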

/// Test a GEMM kernel using all square matrices up to a given size, plus
/// various other "interesting" size combinations.
fn test_gemm_various_input_sizes<LhsT, RhsT, OutT>(
gemm: Option<&GemmExecutor<LhsT, RhsT, OutT>>,
mut lhs_gen: Option<&mut dyn FnMut() -> LhsT>,
) -> Result<(), Box<dyn Error>>
where
LhsT: GemmInT,
@@ -1586,7 +1617,11 @@ mod tests {

for (lhs_size, rhs_size) in cases {
let mut rng = XorShiftRng::new(1234);
let a = NdTensor::<LhsT, 2>::rand(lhs_size, &mut rng);
let a = if let Some(lhs_gen) = lhs_gen.as_mut() {
NdTensor::<LhsT, 2>::from_simple_fn(lhs_size, lhs_gen)
} else {
NdTensor::<LhsT, 2>::rand(lhs_size, &mut rng)
};
let b = NdTensor::<RhsT, 2>::rand(rhs_size, &mut rng);

let result = run_matmul(a.view(), b.view(), None, gemm).unwrap();
@@ -1608,47 +1643,49 @@
#[test]
fn test_gemm_with_fma_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::Fma).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

#[cfg(feature = "avx512")]
#[cfg(target_arch = "x86_64")]
#[test]
fn test_gemm_with_avx512_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::Avx512).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

#[cfg(target_arch = "aarch64")]
#[test]
fn test_gemm_with_arm_neon_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::ArmNeon).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

// This duplicates one of the other `test_gemm_with_XXX_kernel` tests
// depending on what the preferred kernel is. That's OK as long as this
// test is fast.
#[test]
fn test_gemm_with_auto_kernel() -> Result<(), Box<dyn Error>> {
test_gemm_various_input_sizes::<f32, f32, f32>(None)
test_gemm_various_input_sizes::<f32, f32, f32>(None, None)
}

#[test]
fn test_gemm_with_generic_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::Generic).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

#[test]
fn test_gemm_u8i8_i32() -> Result<(), Box<dyn Error>> {
let mut rng = ReducedRangeRng::new();
let gemm = GemmExecutor::<u8, i8, i32>::default();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), Some(&mut || rng.next_u8()))
}

#[test]
fn test_gemm_u8i8_i32_zero_point() {
let mut rng = XorShiftRng::new(1234);
let mut lhs_rng = ReducedRangeRng::new();
let mut rhs_rng = XorShiftRng::new(1234);

struct Case {
m: usize,
@@ -1664,8 +1701,8 @@
];

for Case { m, n, k } in cases {
let a = NdTensor::<u8, 2>::rand([m, k], &mut rng);
let b = NdTensor::<i8, 2>::rand([k, n], &mut rng);
let a = NdTensor::<u8, 2>::from_simple_fn([m, k], || lhs_rng.next_u8());
let b = NdTensor::<i8, 2>::rand([k, n], &mut rhs_rng);
let a_zero_point: Vec<_> = (0..a.rows() as u8).collect();
let b_zero_point: Vec<_> = (0..b.cols() as i8).collect();
let opts = Some(GemmOpts {
