moved cuda kernel
Andrei Panferov committed Feb 6, 2024
1 parent 78cc9a8 commit 1278164
Showing 6 changed files with 6 additions and 6 deletions.
1 change: 0 additions & 1 deletion inference_lib/src/aqlm/__init__.py
@@ -1,3 +1,2 @@
-import aqlm.cuda
 import aqlm.inference_kernels
 from aqlm.inference import QuantizedLinear
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 6 additions & 5 deletions inference_lib/src/aqlm/inference_kernels/kernel_selector.py
@@ -5,9 +5,6 @@
 import torch.nn.functional as F
 from aqlm.utils import _dequantize_weight, unpack_int_data
 
-from .numba import numba_gemm_lut
-from .triton_kernel import triton_matmul
-
 
 def forward_pass_quantized_linear(
     input: torch.Tensor,
@@ -19,16 +16,20 @@ def forward_pass_quantized_linear(
     num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
     match (input.is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
         case (True, 1, 65536, 1, 8):
-            from aqlm.cuda.cuda_kernel import cuda_gemm_1x16
+            from .cuda_kernel import cuda_gemm_1x16
 
             return cuda_gemm_1x16(input, codes, codebooks, scales, bias)
         case (True, 2, 256, 1, 8):
-            from aqlm.cuda.cuda_kernel import cuda_gemm_2x8
+            from .cuda_kernel import cuda_gemm_2x8
 
             return cuda_gemm_2x8(input, codes, codebooks, scales, bias)
         case (True, _, _, _, _):
+            from .triton_kernel import triton_matmul
+
             return triton_matmul(input, codes, codebooks, scales, bias)
         case (False, _, 256, 1, _):
+            from .numba import numba_gemm_lut
+
             return numba_gemm_lut(input, codes, codebooks, scales, bias)
         case _:
             dequantized_weight = _dequantize_weight(
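For reference, below is a minimal, self-contained sketch of the dispatch logic kernel_selector.py implements after this change. The helper name select_kernel and its string return values are hypothetical and used only for illustration (they are not part of the AQLM API); the shape tuples mirror the cases in the diff above, where the real code lazily imports and calls the corresponding kernel inside each branch.

def select_kernel(is_cuda: bool, num_codebooks: int, codebook_size: int,
                  out_group_size: int, in_group_size: int) -> str:
    """Name the kernel the selector would route to for a given configuration."""
    match (is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
        case (True, 1, 65536, 1, 8):
            return "cuda_gemm_1x16"   # specialized CUDA kernel: 1 codebook of 2^16 codes
        case (True, 2, 256, 1, 8):
            return "cuda_gemm_2x8"    # specialized CUDA kernel: 2 codebooks of 2^8 codes
        case (True, _, _, _, _):
            return "triton_matmul"    # generic GPU path via Triton
        case (False, _, 256, 1, _):
            return "numba_gemm_lut"   # CPU path via the Numba LUT GEMM
        case _:
            return "dequantize"       # fall back to dequantizing the weight


if __name__ == "__main__":
    print(select_kernel(True, 1, 65536, 1, 8))   # cuda_gemm_1x16
    print(select_kernel(False, 4, 256, 1, 16))   # numba_gemm_lut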
