Skip to content

Commit

Permalink
faster triton kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrei Panferov committed Jan 22, 2024
1 parent 0ebda08 commit 76bda2a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
10 changes: 4 additions & 6 deletions src/inference_kernels/triton_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def _aqlm_gemv_simple(
# Stage 1: load input data
input_vec = tl.load(
input_vec_ptr
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
-        + tl.arange(0, in_group_size)[None, None, :],
-        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
+        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size
+        + tl.arange(0, in_group_size)[None, None, None, :],
+        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups,
)
-    # [in_features//in_group_size, 1, group_size]
+    # [in_features//in_group_size, 1, 1, group_size]
# Note: we could simply load input_vec then reshape
# input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
# input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
Expand Down Expand Up @@ -98,8 +98,6 @@ def _aqlm_gemv_simple(
weights_i = weights_i.to(tl.float32)
input_vec = input_vec.to(tl.float32)
# ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
-    # ^-- [in_features // in_group_size, out_group_size, in_group_size]

if out_group_size == 1:
scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar
Expand Down
10 changes: 4 additions & 6 deletions transformers/common/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ def _aqlm_gemv_simple(
# Stage 1: load input data
input_vec = tl.load(
input_vec_ptr
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
-        + tl.arange(0, in_group_size)[None, None, :],
-        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
+        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size
+        + tl.arange(0, in_group_size)[None, None, None, :],
+        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups,
)
-    # [in_features//in_group_size, 1, group_size]
+    # [in_features//in_group_size, 1, 1, group_size]
# Note: we could simply load input_vec then reshape
# input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
# input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
Expand Down Expand Up @@ -237,8 +237,6 @@ def _aqlm_gemv_simple(
weights_i = weights_i.to(tl.float32)
input_vec = input_vec.to(tl.float32)
# ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
-    # ^-- [in_features // in_group_size, out_group_size, in_group_size]

if out_group_size == 1:
scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar
Expand Down

0 comments on commit 76bda2a

Please sign in to comment.