From 76bda2a5e9f21f4a29ae2c9f43624f53647842af Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Mon, 22 Jan 2024 13:33:38 +0300 Subject: [PATCH] faster triton kernel --- src/inference_kernels/triton_kernel.py | 10 ++++------ transformers/common/inference.py | 10 ++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/inference_kernels/triton_kernel.py b/src/inference_kernels/triton_kernel.py index b3ca6143..c8185821 100644 --- a/src/inference_kernels/triton_kernel.py +++ b/src/inference_kernels/triton_kernel.py @@ -49,11 +49,11 @@ def _aqlm_gemv_simple( # Stage 1: load input data input_vec = tl.load( input_vec_ptr - + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size - + tl.arange(0, in_group_size)[None, None, :], - mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups, + + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size + + tl.arange(0, in_group_size)[None, None, None, :], + mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups, ) - # [in_features//in_group_size, 1, group_size] + # [in_features//in_group_size, 1, 1, group_size] # Note: we could simply load input_vec then reshape # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features] # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size]) @@ -98,8 +98,6 @@ def _aqlm_gemv_simple( weights_i = weights_i.to(tl.float32) input_vec = input_vec.to(tl.float32) # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size] - weights_i = tl.sum(weights_i, axis=1) # sum codebooks as per additive quantization - # ^-- [in_features // in_group_size, out_group_size, in_group_size] if out_group_size == 1: scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar diff --git a/transformers/common/inference.py b/transformers/common/inference.py index 86ba521d..47c84f17 100644 --- a/transformers/common/inference.py +++ b/transformers/common/inference.py @@ -188,11 +188,11 @@ def _aqlm_gemv_simple( # Stage 1: load input data input_vec = tl.load( input_vec_ptr - + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size - + tl.arange(0, in_group_size)[None, None, :], - mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups, + + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size + + tl.arange(0, in_group_size)[None, None, None, :], + mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups, ) - # [in_features//in_group_size, 1, group_size] + # [in_features//in_group_size, 1, 1, group_size] # Note: we could simply load input_vec then reshape # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features] # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size]) @@ -237,8 +237,6 @@ def _aqlm_gemv_simple( weights_i = weights_i.to(tl.float32) input_vec = input_vec.to(tl.float32) # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size] - weights_i = tl.sum(weights_i, axis=1) # sum codebooks as per additive quantization - # ^-- [in_features // in_group_size, out_group_size, in_group_size] if out_group_size == 1: scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar