Skip to content

Commit

Permalink
Implement int8 GEMM kernel for x86_64 using AVX2 instructions
Browse files Browse the repository at this point in the history
This uses the pre-VNNI instruction sequence for u8 x i8 -> i32 dot
products:

```
_mm256_maddubs_epi16
_mm256_madd_epi16
_mm256_add_epi32
```

The first instruction can saturate when adding pairs of intermediate 16-bit
signed ints. This can be avoided by limiting the range of the u8 LHS input.
  • Loading branch information
robertknight committed Jan 13, 2025
1 parent 8e612d7 commit fc61865
Show file tree
Hide file tree
Showing 2 changed files with 319 additions and 8 deletions.
37 changes: 29 additions & 8 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@ impl Default for GemmExecutor<f32, f32, f32> {

impl Default for GemmExecutor<u8, i8, i32> {
fn default() -> Self {
#[cfg(target_arch = "x86_64")]
if let Some(gemm) = Self::from_kernel::<kernels::x86_64::Avx2Int8Kernel>(KernelType::Fma) {
return gemm;
}
Self::from_kernel::<GenericKernel>(KernelType::Generic).unwrap()
}
}
Expand Down Expand Up @@ -1543,6 +1547,7 @@ mod tests {
/// various other "interesting" size combinations.
fn test_gemm_various_input_sizes<LhsT, RhsT, OutT>(
gemm: Option<&GemmExecutor<LhsT, RhsT, OutT>>,
mut lhs_gen: Option<&mut dyn FnMut() -> LhsT>,
) -> Result<(), Box<dyn Error>>
where
LhsT: GemmInT,
Expand Down Expand Up @@ -1586,7 +1591,11 @@ mod tests {

for (lhs_size, rhs_size) in cases {
let mut rng = XorShiftRng::new(1234);
let a = NdTensor::<LhsT, 2>::rand(lhs_size, &mut rng);
let a = if let Some(lhs_gen) = lhs_gen.as_mut() {
NdTensor::<LhsT, 2>::from_simple_fn(lhs_size, lhs_gen)
} else {
NdTensor::<LhsT, 2>::rand(lhs_size, &mut rng)
};
let b = NdTensor::<RhsT, 2>::rand(rhs_size, &mut rng);

let result = run_matmul(a.view(), b.view(), None, gemm).unwrap();
Expand All @@ -1608,42 +1617,50 @@ mod tests {
#[test]
fn test_gemm_with_fma_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::Fma).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

#[cfg(feature = "avx512")]
#[cfg(target_arch = "x86_64")]
#[test]
fn test_gemm_with_avx512_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::Avx512).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

#[cfg(target_arch = "aarch64")]
#[test]
fn test_gemm_with_arm_neon_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::ArmNeon).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

// This duplicates one of the other `test_gemm_with_XXX_kernel` tests
// depending on what the preferred kernel is. That's OK as long as this
// test is fast.
#[test]
fn test_gemm_with_auto_kernel() -> Result<(), Box<dyn Error>> {
test_gemm_various_input_sizes::<f32, f32, f32>(None)
test_gemm_various_input_sizes::<f32, f32, f32>(None, None)
}

#[test]
fn test_gemm_with_generic_kernel() -> Result<(), Box<dyn Error>> {
let gemm = GemmExecutor::with_kernel(KernelType::Generic).unwrap();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(Some(&gemm), None)
}

#[test]
fn test_gemm_u8i8_i32() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
let gemm = GemmExecutor::<u8, i8, i32>::default();
test_gemm_various_input_sizes(Some(&gemm))
test_gemm_various_input_sizes(
Some(&gemm),
Some(&mut || {
// To avoid saturation issues under AVX2 we restrict the LHS input
// to [0, 127].
(rng.next_u64() % 128) as u8
}),
)
}

#[test]
Expand All @@ -1664,7 +1681,11 @@ mod tests {
];

for Case { m, n, k } in cases {
let a = NdTensor::<u8, 2>::rand([m, k], &mut rng);
let a = NdTensor::<u8, 2>::from_simple_fn([m, k], || {
// To avoid saturation issues under AVX2 we restrict the LHS input
// to [0, 127].
(rng.next_u64() % 128) as u8
});
let b = NdTensor::<i8, 2>::rand([k, n], &mut rng);
let a_zero_point: Vec<_> = (0..a.rows() as u8).collect();
let b_zero_point: Vec<_> = (0..b.cols() as i8).collect();
Expand Down
Loading

0 comments on commit fc61865

Please sign in to comment.