vLLM-Ext: Resolved ALiBI bias regression
Changes:
- Optimized ALiBI memory usage.
  - Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow
    large models to run with restricted prompt lengths.
- Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve
  accuracy issue on long sequences.
- Updated jais, mpt, falcon, baichuan, and bloom to work with ALiBI.
  - Due to bloom's 176B parameter size I was unable to test that model,
    though its changes are the simplest.
- Works in lazy and eager mode.
- ALiBI support currently requires "VLLM_PROMPT_USE_FUSEDSDPA=false",
  "VLLM_CONTIGUOUS_PA=false", and "VLLM_PA_SOFTMAX_IMPL=wsum_head_amax"
  (see the example after this list).

Remaining TODO:
- Resolve quality issue when running prompts of significantly different
  lengths.
- Resolve issue with contiguous PA.
- Integrate support for GQA along with MHA.

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
tannervoas742 and xhaihao committed Nov 27, 2024
1 parent ac9740d commit c0fb257
Showing 1 changed file with 41 additions and 11 deletions: vllm_hpu_extension/ops.py

@@ -111,7 +111,8 @@ def scatter_reduce(attn, batch_size, block_groups, **rest):
 
 
 DEFAULT_PA_SOFTMAX_IMPL = "index_reduce" if "index_reduce" in capabilities() else "wsum_head_amax"
-normalize = SoftmaxNormalization(os.environ.get('VLLM_PA_SOFTMAX_IMPL', DEFAULT_PA_SOFTMAX_IMPL).split(','))
+ACTUAL_PA_SOFTMAX_IMPL = os.environ.get('VLLM_PA_SOFTMAX_IMPL', DEFAULT_PA_SOFTMAX_IMPL)
+normalize = SoftmaxNormalization(ACTUAL_PA_SOFTMAX_IMPL.split(','))
 
 
 def b2b_impl(tensor, block_mapping, matmul_op):
@@ -141,10 +142,24 @@ def block_softmax(batch_size, attn, block_mapping, block_scales, block_groups):
     return attn
 
 
-def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
-            block_bias, block_scales, block_groups, scale, matmul_qk_op,
-            matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
-            keys_fetch_func, values_fetch_func):
+def flat_pa(
+    query,
+    key_cache,
+    value_cache,
+    block_list,
+    block_mapping,
+    block_bias,
+    block_scales,
+    block_groups,
+    scale,
+    position_bias,
+    matmul_qk_op,
+    matmul_av_op,
+    batch2block_matmul_op,
+    block2batch_matmul_op,
+    keys_fetch_func,
+    values_fetch_func,
+):
     batch_size = query.size(0)
     q_heads = query.size(1)
     kv_heads = key_cache.size(2)
@@ -162,10 +177,17 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
     else:
         key = key.transpose(2, 3)
 
-    attn = matmul_qk_op(query, key) + block_bias
-    attn = block_softmax(batch_size, attn, block_mapping, block_scales, block_groups)
-    attn = matmul_av_op(attn, value)
-    attn = block2batch(attn, block_mapping, block2batch_matmul_op)
+    attn = matmul_qk_op(query, key)
+    attn_orig_dtype = attn.dtype
+    if position_bias is not None:
+        attn = attn.to(dtype=position_bias.dtype)
+        attn.add_(position_bias.unsqueeze(-2)[:attn.shape[0]])
+    if block_bias is not None:
+        attn = attn + block_bias.to(dtype=attn.dtype)
+    attn = block_softmax(batch_size, attn, block_mapping.to(dtype=attn.dtype), block_scales.to(dtype=attn.dtype), block_groups)
+    attn = matmul_av_op(attn, value.to(dtype=attn.dtype))
+    attn = block2batch(attn, block_mapping.to(dtype=attn.dtype), block2batch_matmul_op)
+    attn = attn.to(dtype=attn_orig_dtype)
     attn = attn.squeeze(-2)
     if kv_heads != q_heads:
         attn = attn.flatten(1, 2)
@@ -182,6 +204,7 @@ def prompt_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     attn_bias: Optional[torch.Tensor] = None,
+    position_bias: Optional[torch.Tensor] = None,
     p: float = 0.0,
     scale: Optional[float] = None,
     matmul_qk_op=torch.matmul,
@@ -199,13 +222,20 @@ def prompt_attention(
             query = query.unflatten(1, (kv_heads, -1))
             key = key.unflatten(1, (kv_heads, 1))
             value = value.unflatten(1, (kv_heads, 1))
+            if position_bias is not None:
+                position_bias = position_bias.unsqueeze(2)
             if attn_bias is not None:
                 attn_bias = attn_bias.unsqueeze(2)
         attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
+        attn_orig_dtype = attn_weights.dtype
+        if position_bias is not None:
+            attn_weights = attn_weights.to(dtype=position_bias.dtype)
+            attn_weights.add_(position_bias)
         if attn_bias is not None:
-            attn_weights.add_(attn_bias)
+            attn_weights.add_(attn_bias.to(dtype=attn_weights.dtype))
         attn_weights = softmax_op(attn_weights, dim=-1)
-        attn_weights = matmul_av_op(attn_weights, value)
+        attn_weights = matmul_av_op(attn_weights, value.to(dtype=attn_weights.dtype))
+        attn_weights = attn_weights.to(dtype=attn_orig_dtype)
         if query_heads != kv_heads:
             attn_weights = attn_weights.flatten(1, 2)
     else:
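
For reference, a minimal sketch (not part of this commit) of how the ALiBI
position bias carried by the new position_bias argument is typically
constructed: slopes follow the standard ALiBI recipe (powers of
2^(-8/num_heads)) and the pre-softmax bias is slope * (key_pos - query_pos).
The shapes, the float32 choice, and the plain reference attention below are
illustrative assumptions, not the extension's API.

    import math
    import torch

    def alibi_slopes(num_heads: int) -> torch.Tensor:
        # Standard ALiBI slopes for a power-of-two head count:
        # slope_h = (2 ** (-8 / num_heads)) ** (h + 1)
        start = 2 ** (-8.0 / num_heads)
        return torch.tensor([start ** (i + 1) for i in range(num_heads)])

    def build_position_bias(num_heads: int, seq_len: int,
                            dtype=torch.float32) -> torch.Tensor:
        # bias[h, i, j] = slope[h] * (j - i) for j <= i, 0 for future tokens;
        # float32 mirrors the VLLM_ALIBI_USE_FLOAT32_BIASES accuracy fix.
        slopes = alibi_slopes(num_heads).to(dtype)
        pos = torch.arange(seq_len)
        distance = (pos.view(1, -1) - pos.view(-1, 1)).clamp(max=0)
        return slopes.view(-1, 1, 1) * distance.to(dtype)  # [heads, seq, seq]

    # Tiny reference attention showing where the bias enters (pre-softmax).
    batch, heads, seq, head_dim = 1, 8, 16, 64
    q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))
    bias = build_position_bias(heads, seq)
    causal = torch.full((seq, seq), float("-inf")).triu(1)
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
    scores = scores + bias.unsqueeze(0) + causal
    out = torch.softmax(scores, dim=-1) @ v  # [batch, heads, seq, head_dim]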
