Revert "vLLM-Ext: Full enabling of ALiBi (#34)" (#59)
This reverts commit 0766759.
michalkuligowski authored Dec 13, 2024
1 parent fffc2c0 · commit 11b8d9d
Showing 1 changed file with 11 additions and 67 deletions.
vllm_hpu_extension/ops.py (78 changes: 11 additions & 67 deletions)
@@ -105,24 +105,10 @@ def pa(attn, value, block_groups, block_mapping, block_scales, batch_size,
pa_impl = pipelined_pa if pipelined_pa_enabled else pa


def flat_pa(
query,
key_cache,
value_cache,
block_list,
block_mapping,
block_bias,
block_scales,
block_groups,
scale,
position_bias,
matmul_qk_op,
matmul_av_op,
batch2block_matmul_op,
block2batch_matmul_op,
keys_fetch_func,
values_fetch_func,
):
def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
block_bias, block_scales, block_groups, scale, matmul_qk_op,
matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
keys_fetch_func, values_fetch_func):
batch_size = query.size(0)
q_heads = query.size(1)
kv_heads = key_cache.size(2)
@@ -132,39 +118,20 @@ def flat_pa(
value = values_fetch_func(value_cache, block_list).transpose(1, 2)
block_bias = block_bias.view(key.size(0), 1, 1, -1)
if kv_heads != q_heads:
block_bias = block_bias.unsqueeze(1)
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
if position_bias is not None:
position_bias = position_bias.unflatten(1, (kv_heads, -1))
if block_bias is not None:
block_bias = block_bias.unsqueeze(2)
key = key.transpose(-2, -1)

attn = matmul_qk_op(query, key)
if position_bias is not None:
if attn.dtype != position_bias.dtype:
attn = attn.to(dtype=position_bias.dtype)
attn.add_(position_bias.unsqueeze(-2))
if block_bias is not None:
if attn.dtype != block_bias.dtype:
block_bias = block_bias.to(dtype=attn.dtype)
attn.add_(block_bias)

if attn.dtype != block_mapping.dtype:
block_mapping = block_mapping.to(dtype=attn.dtype)
if attn.dtype != block_scales.dtype:
block_scales = block_scales.to(dtype=attn.dtype)
if attn.dtype != value.dtype:
value = value.to(dtype=attn.dtype)
key = key.transpose(3, 4)
else:
key = key.transpose(2, 3)

attn = matmul_qk_op(query, key) + block_bias
attn = pa_impl(attn, value, block_groups, block_mapping, block_scales=block_scales,
batch_size=batch_size, matmul_av_op=matmul_av_op,
batch2block_matmul_op=batch2block_matmul_op, block2batch_matmul_op=block2batch_matmul_op)
attn = block2batch(attn, block_mapping, block2batch_matmul_op)
if attn.dtype != query.dtype:
attn = attn.to(dtype=query.dtype)
attn = attn.squeeze(-2)

if kv_heads != q_heads:
attn = attn.flatten(1, 2)
return attn
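
Both the reverted and the restored flat_pa rely on the same grouped-query reshape: query heads are unflattened into kv_heads groups so that a single key/value head is broadcast across q_heads // kv_heads query heads. A minimal, self-contained sketch of that reshape, with illustrative shapes not taken from this diff:

```python
import torch

# Grouped-query reshape as used around flat_pa/prompt_attention (illustrative shapes).
batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)

# Split the 8 query heads into 2 groups of 4 so each group shares one kv head.
query = query.unflatten(1, (kv_heads, -1))            # [2, 2, 4, 16, 64]
key = key.unflatten(1, (kv_heads, 1))                 # [2, 2, 1, 16, 64]
scores = torch.matmul(query, key.transpose(-2, -1))   # kv head broadcasts over the group
print(scores.shape)                                   # torch.Size([2, 2, 4, 16, 16])
```
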
@@ -196,8 +163,6 @@ def prompt_attention(
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
position_bias_offset: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
matmul_qk_op=torch.matmul,
@@ -216,33 +181,13 @@
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
if position_bias is not None:
position_bias = position_bias.unflatten(1, (kv_heads, -1))
if position_bias_offset is not None:
position_bias_offset = position_bias_offset.unflatten(1, (kv_heads, -1))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
key = key.transpose(-2, -1)
attn_weights = matmul_qk_op(query * scale, key)

if position_bias is not None:
if attn_weights.dtype != position_bias.dtype:
attn_weights = attn_weights.to(dtype=position_bias.dtype)
attn_weights.add_(position_bias)
if position_bias_offset is not None:
attn_weights.add_(position_bias_offset.unsqueeze(-1).unsqueeze(-1))
attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)

attn_weights = softmax_op(attn_weights, dim=-1)
if attn_weights.dtype != value.dtype:
value = value.to(dtype=attn_weights.dtype)
attn_weights = matmul_av_op(attn_weights, value)
if attn_weights.dtype != query.dtype:
attn_weights = attn_weights.to(dtype=query.dtype)

if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
else:
@@ -261,7 +206,6 @@ def prompt_attention(
attn_weights = fsdpa_op(query, key, value, None, 0.0, True,
scale, softmax_mode, recompute_mode,
valid_seq_lengths, 'right')

attn_weights = attn_weights.transpose(1, 2)
return attn_weights
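
The removed position_bias and position_bias_offset arguments carried the ALiBi bias that #34 threaded into these kernels. For reference, the standard ALiBi construction (per-head slopes times token distance, added to the scores before softmax) looks roughly like the sketch below; this follows the published formulation, not necessarily how the reverted code built the tensor:

```python
import math
import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric slope schedule from the ALiBi paper; non-power-of-two head
    # counts interleave slopes taken from the next power of two.
    def pow2_slopes(n: int) -> list:
        start = 2.0 ** (-8.0 / n)
        return [start ** (i + 1) for i in range(n)]
    if math.log2(num_heads).is_integer():
        slopes = pow2_slopes(num_heads)
    else:
        closest = 2 ** math.floor(math.log2(num_heads))
        slopes = (pow2_slopes(closest)
                  + pow2_slopes(2 * closest)[0::2][:num_heads - closest])
    return torch.tensor(slopes)

def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    # bias[h, i, j] = slope[h] * (j - i): zero on the diagonal, increasingly
    # negative for tokens further in the past; positions with j > i are
    # assumed to be handled by the causal mask.
    pos = torch.arange(seq_len)
    distance = pos.view(1, -1) - pos.view(-1, 1)
    return alibi_slopes(num_heads).view(-1, 1, 1) * distance

bias = alibi_bias(num_heads=4, seq_len=8)  # [4, 8, 8], added to QK^T before softmax
```
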


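One pattern worth noting in the deleted lines: before each in-place add_ of a bias, the code aligns dtypes, promoting the attention scores to the bias dtype or casting the bias to the score dtype, so mixed-precision runs do not silently downcast the bias. A standalone sketch of that pattern, with hypothetical tensors:

```python
import torch

def add_bias_inplace(scores: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Promote the scores to the (typically wider) bias dtype before adding,
    # mirroring the dtype checks in the removed ALiBi code paths.
    if scores.dtype != bias.dtype:
        scores = scores.to(dtype=bias.dtype)
    scores.add_(bias)  # in-place add; bias broadcasts over the batch dim
    return scores

scores = torch.zeros(2, 4, 8, 8, dtype=torch.bfloat16)
bias = torch.randn(4, 8, 8, dtype=torch.float32)
out = add_bias_inplace(scores, bias)  # float32 result, bias precision preserved
```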