diff --git a/inference_lib/src/aqlm/inference_kernels/kernel_selector.py b/inference_lib/src/aqlm/inference_kernels/kernel_selector.py
index ea55f342..d72fe6ed 100644
--- a/inference_lib/src/aqlm/inference_kernels/kernel_selector.py
+++ b/inference_lib/src/aqlm/inference_kernels/kernel_selector.py
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 
 from aqlm.utils import _dequantize_weight, unpack_int_data
 
+from .numba import numba_gemm_lut
 from .triton_kernel import triton_matmul
 
 
@@ -27,6 +28,8 @@ def forward_pass_quantized_linear(
             return cuda_gemm_2x8(input, codes, codebooks, scales, bias)
         case (True, _, _, _, _):
             return triton_matmul(input, codes, codebooks, scales, bias)
+        case (False, _, 256, 1, _):
+            return numba_gemm_lut(input, codes, codebooks, scales, bias)
         case _:
             dequantized_weight = _dequantize_weight(
                 unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
diff --git a/inference_lib/src/aqlm/inference_kernels/numba.py b/inference_lib/src/aqlm/inference_kernels/numba.py
new file mode 100644
index 00000000..aac6675b
--- /dev/null
+++ b/inference_lib/src/aqlm/inference_kernels/numba.py
@@ -0,0 +1,66 @@
+from typing import Optional
+
+import numba
+import numpy as np
+import torch
+
+COMPILED_KERNELS = {}
+
+
+def numba_gemm_lut(
+    input: torch.Tensor,  # [..., in_features]
+    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
+    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
+    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
+    bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+    input_shape = input.shape
+    input = input.reshape(-1, input_shape[-1])
+
+    device, dtype = codebooks.device, codebooks.dtype
+    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
+    in_features = input.shape[1]
+    out_features = codes.shape[0] * out_group_size
+    assert input.ndim == 2
+    assert out_group_size == 1
+    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
+    assert in_features % in_group_size == 0
+    assert codebook_size == 2**8
+    assert codes.dtype == torch.int8
+    assert input.dtype == torch.float32 and codebooks.dtype == torch.float32
+
+    kernel_key = (in_group_size, out_features, in_features, num_codebooks)
+    if kernel_key not in COMPILED_KERNELS:
+
+        @numba.njit(parallel=False)
+        def numba_gemv_lut_(x, codebooks, codes_alt, scales):
+            # Look-up table: dot product of every input group with every codebook entry.
+            lut = x.reshape(-1, in_group_size) @ codebooks.reshape(-1, in_group_size).T
+            lut = lut.reshape(-1, num_codebooks, codebook_size)
+
+            output_vec = np.zeros(out_features, dtype=x.dtype)
+            for j in range(in_features // in_group_size):
+                for i in range(out_features):
+                    for c in range(num_codebooks):
+                        # int8 codes index the 256-entry LUT; negative values wrap
+                        # around, matching the unsigned interpretation of the codes.
+                        output_vec[i] += lut[j, c, codes_alt[j, i, c]]
+            output_vec *= scales.flatten()
+            return output_vec
+
+        COMPILED_KERNELS[kernel_key] = numba_gemv_lut_
+    compiled_kernel = COMPILED_KERNELS[kernel_key]
+
+    output = torch.zeros(input.shape[0], out_features, device=device, dtype=dtype)
+    for i in range(input.shape[0]):
+        output[i] = torch.tensor(
+            compiled_kernel(
+                input[i].numpy(),
+                codebooks.numpy(),
+                torch.permute(codes, (1, 0, 2)).contiguous().numpy(),
+                scales.numpy(),
+            )
+        )
+    if bias is not None:
+        output += bias
+    return output.reshape(input_shape[:-1] + (-1,))
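
For reference only, not part of the patch: a minimal smoke-test sketch of the new CPU path. The shapes below (one codebook of 256 entries, out_group_size 1, in_group_size 8, random codes and codebooks) are illustrative assumptions, not values taken from the patch.

    import torch
    from aqlm.inference_kernels.numba import numba_gemm_lut

    # Assumed illustrative configuration: 1 codebook x 256 entries, groups of 1 x 8.
    num_codebooks, codebook_size, out_group_size, in_group_size = 1, 256, 1, 8
    in_features, out_features = 64, 32

    x = torch.randn(2, in_features, dtype=torch.float32)
    codes = torch.randint(
        -128, 128, (out_features, in_features // in_group_size, num_codebooks), dtype=torch.int8
    )
    codebooks = torch.randn(
        num_codebooks, codebook_size, out_group_size, in_group_size, dtype=torch.float32
    )
    scales = torch.ones(out_features, 1, 1, 1, dtype=torch.float32)

    out = numba_gemm_lut(x, codes, codebooks, scales, bias=None)
    print(out.shape)  # expected: torch.Size([2, 32])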