From 11b8d9deba62d3d519ce407bf7c22b2aecb57935 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Kuligowski?=
Date: Fri, 13 Dec 2024 15:46:29 +0100
Subject: [PATCH] Revert "vLLM-Ext: Full enabling of ALiBi (#34)" (#59)

This reverts commit 07667596f5cb8c8540ef9588ce65bbb311363b66.
---
 vllm_hpu_extension/ops.py | 78 ++++++---------------------------------
 1 file changed, 11 insertions(+), 67 deletions(-)

diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index f4e98a6a..e6726268 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -105,24 +105,10 @@ def pa(attn, value, block_groups, block_mapping, block_scales, batch_size,
 pa_impl = pipelined_pa if pipelined_pa_enabled else pa
 
 
-def flat_pa(
-    query,
-    key_cache,
-    value_cache,
-    block_list,
-    block_mapping,
-    block_bias,
-    block_scales,
-    block_groups,
-    scale,
-    position_bias,
-    matmul_qk_op,
-    matmul_av_op,
-    batch2block_matmul_op,
-    block2batch_matmul_op,
-    keys_fetch_func,
-    values_fetch_func,
-):
+def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
+            block_bias, block_scales, block_groups, scale, matmul_qk_op,
+            matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
+            keys_fetch_func, values_fetch_func):
     batch_size = query.size(0)
     q_heads = query.size(1)
     kv_heads = key_cache.size(2)
@@ -132,39 +118,20 @@ def flat_pa(
     value = values_fetch_func(value_cache, block_list).transpose(1, 2)
     block_bias = block_bias.view(key.size(0), 1, 1, -1)
     if kv_heads != q_heads:
+        block_bias = block_bias.unsqueeze(1)
         query = query.unflatten(1, (kv_heads, -1))
         key = key.unflatten(1, (kv_heads, 1))
         value = value.unflatten(1, (kv_heads, 1))
-        if position_bias is not None:
-            position_bias = position_bias.unflatten(1, (kv_heads, -1))
-        if block_bias is not None:
-            block_bias = block_bias.unsqueeze(2)
-    key = key.transpose(-2, -1)
-
-    attn = matmul_qk_op(query, key)
-    if position_bias is not None:
-        if attn.dtype != position_bias.dtype:
-            attn = attn.to(dtype=position_bias.dtype)
-        attn.add_(position_bias.unsqueeze(-2))
-    if block_bias is not None:
-        if attn.dtype != block_bias.dtype:
-            block_bias = block_bias.to(dtype=attn.dtype)
-        attn.add_(block_bias)
-
-    if attn.dtype != block_mapping.dtype:
-        block_mapping = block_mapping.to(dtype=attn.dtype)
-    if attn.dtype != block_scales.dtype:
-        block_scales = block_scales.to(dtype=attn.dtype)
-    if attn.dtype != value.dtype:
-        value = value.to(dtype=attn.dtype)
+        key = key.transpose(3, 4)
+    else:
+        key = key.transpose(2, 3)
+
+    attn = matmul_qk_op(query, key) + block_bias
     attn = pa_impl(attn, value, block_groups, block_mapping,
                    block_scales=block_scales, batch_size=batch_size, matmul_av_op=matmul_av_op,
                    batch2block_matmul_op=batch2block_matmul_op, block2batch_matmul_op=block2batch_matmul_op)
     attn = block2batch(attn, block_mapping, block2batch_matmul_op)
-    if attn.dtype != query.dtype:
-        attn = attn.to(dtype=query.dtype)
     attn = attn.squeeze(-2)
-
     if kv_heads != q_heads:
         attn = attn.flatten(1, 2)
     return attn
@@ -196,8 +163,6 @@ def prompt_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     attn_bias: Optional[torch.Tensor] = None,
-    position_bias: Optional[torch.Tensor] = None,
-    position_bias_offset: Optional[torch.Tensor] = None,
     p: float = 0.0,
     scale: Optional[float] = None,
     matmul_qk_op=torch.matmul,
@@ -216,33 +181,13 @@ def prompt_attention(
             query = query.unflatten(1, (kv_heads, -1))
             key = key.unflatten(1, (kv_heads, 1))
             value = value.unflatten(1, (kv_heads, 1))
-            if position_bias is not None:
-                position_bias = position_bias.unflatten(1, (kv_heads, -1))
-            if position_bias_offset is not None:
-                position_bias_offset = position_bias_offset.unflatten(1, (kv_heads, -1))
             if attn_bias is not None:
                 attn_bias = attn_bias.unsqueeze(2)
-        key = key.transpose(-2, -1)
-        attn_weights = matmul_qk_op(query * scale, key)
-
-        if position_bias is not None:
-            if attn_weights.dtype != position_bias.dtype:
-                attn_weights = attn_weights.to(dtype=position_bias.dtype)
-            attn_weights.add_(position_bias)
-        if position_bias_offset is not None:
-            attn_weights.add_(position_bias_offset.unsqueeze(-1).unsqueeze(-1))
+        attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
         if attn_bias is not None:
-            if attn_weights.dtype != attn_bias.dtype:
-                attn_bias = attn_bias.to(dtype=attn_weights.dtype)
             attn_weights.add_(attn_bias)
-
         attn_weights = softmax_op(attn_weights, dim=-1)
-        if attn_weights.dtype != value.dtype:
-            value = value.to(dtype=attn_weights.dtype)
         attn_weights = matmul_av_op(attn_weights, value)
-        if attn_weights.dtype != query.dtype:
-            attn_weights = attn_weights.to(dtype=query.dtype)
-
         if query_heads != kv_heads:
             attn_weights = attn_weights.flatten(1, 2)
     else:
@@ -261,7 +206,6 @@ def prompt_attention(
         attn_weights = fsdpa_op(query, key, value, None, 0.0, True,
                                 scale, softmax_mode, recompute_mode,
                                 valid_seq_lengths, 'right')
-
     attn_weights = attn_weights.transpose(1, 2)
     return attn_weights
 
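
# --- Reviewer note, not part of the patch -----------------------------------
# A minimal, standalone sketch of the grouped-query broadcasting pattern that
# the restored flat_pa/prompt_attention paths rely on when q_heads != kv_heads:
# query heads are unflattened into (kv_heads, q_heads_per_kv) while key/value
# are unflattened into (kv_heads, 1), so the matmul broadcasts over the group
# dimension. The shapes and the plain torch.matmul/torch.softmax stand-ins
# below are illustrative assumptions, not the HPU-specific ops used in ops.py.
import torch

batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

if q_heads != kv_heads:
    query = query.unflatten(1, (kv_heads, -1))   # (B, kv, q_per_kv, S, D)
    key = key.unflatten(1, (kv_heads, 1))        # (B, kv, 1,        S, D)
    value = value.unflatten(1, (kv_heads, 1))    # (B, kv, 1,        S, D)

attn = torch.matmul(query, key.transpose(-1, -2))  # broadcasts over q_per_kv
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, value)

if q_heads != kv_heads:
    out = out.flatten(1, 2)                      # back to (B, q_heads, S, D)
print(out.shape)  # torch.Size([2, 8, 16, 64])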