diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index b4c595ca..0cf088cd 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -111,7 +111,8 @@ def scatter_reduce(attn, batch_size, block_groups, **rest):
 
 
 DEFAULT_PA_SOFTMAX_IMPL = "index_reduce" if "index_reduce" in capabilities() else "wsum_head_amax"
-normalize = SoftmaxNormalization(os.environ.get('VLLM_PA_SOFTMAX_IMPL', DEFAULT_PA_SOFTMAX_IMPL).split(','))
+ACTUAL_PA_SOFTMAX_IMPL = os.environ.get('VLLM_PA_SOFTMAX_IMPL', DEFAULT_PA_SOFTMAX_IMPL)
+normalize = SoftmaxNormalization(ACTUAL_PA_SOFTMAX_IMPL.split(','))
 
 
 def b2b_impl(tensor, block_mapping, matmul_op):
@@ -141,10 +142,24 @@ def block_softmax(batch_size, attn, block_mapping, block_scales, block_groups):
     return attn
 
 
-def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
-            block_bias, block_scales, block_groups, scale, matmul_qk_op,
-            matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
-            keys_fetch_func, values_fetch_func):
+def flat_pa(
+    query,
+    key_cache,
+    value_cache,
+    block_list,
+    block_mapping,
+    block_bias,
+    block_scales,
+    block_groups,
+    scale,
+    position_bias,
+    matmul_qk_op,
+    matmul_av_op,
+    batch2block_matmul_op,
+    block2batch_matmul_op,
+    keys_fetch_func,
+    values_fetch_func,
+):
     batch_size = query.size(0)
     q_heads = query.size(1)
     kv_heads = key_cache.size(2)
@@ -162,10 +177,17 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
     else:
         key = key.transpose(2, 3)
 
-    attn = matmul_qk_op(query, key) + block_bias
-    attn = block_softmax(batch_size, attn, block_mapping, block_scales, block_groups)
-    attn = matmul_av_op(attn, value)
-    attn = block2batch(attn, block_mapping, block2batch_matmul_op)
+    attn = matmul_qk_op(query, key)
+    attn_orig_dtype = attn.dtype
+    if position_bias is not None:
+        attn = attn.to(dtype=position_bias.dtype)
+        attn.add_(position_bias.unsqueeze(-2)[:attn.shape[0]])
+    if block_bias is not None:
+        attn = attn + block_bias.to(dtype=attn.dtype)
+    attn = block_softmax(batch_size, attn, block_mapping.to(dtype=attn.dtype), block_scales.to(dtype=attn.dtype), block_groups)
+    attn = matmul_av_op(attn, value.to(dtype=attn.dtype))
+    attn = block2batch(attn, block_mapping.to(dtype=attn.dtype), block2batch_matmul_op)
+    attn = attn.to(dtype=attn_orig_dtype)
     attn = attn.squeeze(-2)
     if kv_heads != q_heads:
         attn = attn.flatten(1, 2)
@@ -182,6 +204,7 @@ def prompt_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     attn_bias: Optional[torch.Tensor] = None,
+    position_bias: Optional[torch.Tensor] = None,
     p: float = 0.0,
     scale: Optional[float] = None,
     matmul_qk_op=torch.matmul,
@@ -199,13 +222,20 @@ def prompt_attention(
             query = query.unflatten(1, (kv_heads, -1))
             key = key.unflatten(1, (kv_heads, 1))
             value = value.unflatten(1, (kv_heads, 1))
+            if position_bias is not None:
+                position_bias = position_bias.unsqueeze(2)
             if attn_bias is not None:
                 attn_bias = attn_bias.unsqueeze(2)
         attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
+        attn_orig_dtype = attn_weights.dtype
+        if position_bias is not None:
+            attn_weights = attn_weights.to(dtype=position_bias.dtype)
+            attn_weights.add_(position_bias)
         if attn_bias is not None:
-            attn_weights.add_(attn_bias)
+            attn_weights.add_(attn_bias.to(dtype=attn_weights.dtype))
         attn_weights = softmax_op(attn_weights, dim=-1)
-        attn_weights = matmul_av_op(attn_weights, value)
+        attn_weights = matmul_av_op(attn_weights, value.to(dtype=attn_weights.dtype))
+        attn_weights = attn_weights.to(dtype=attn_orig_dtype)
         if query_heads != kv_heads:
             attn_weights = attn_weights.flatten(1, 2)
         else: