selecting correct backend for MultiHeadAttention fix
adobrzyniewicz-habana committed Dec 18, 2024
1 parent 88ef381 commit 32ec3e5
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions vllm/attention/layer.py
@@ -191,11 +191,14 @@ def __init__(
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
+        attn_backend_enum = backend_name_to_enum(attn_backend.get_name())
+
+        if attn_backend_enum in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            attn_backend_enum = _Backend.XFORMERS
 
-        self.attn_backend = attn_backend if attn_backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+        self.attn_backend = attn_backend_enum if attn_backend_enum in {
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
         } else _Backend.TORCH_SDPA
 
     def forward(

Note: GitHub Actions / ruff (3.12) flags line 197 of the new code with E501: Line too long (83 > 80).
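As I read this hunk, the fix converts the backend object returned by get_attn_backend() into a _Backend enum member via backend_name_to_enum(attn_backend.get_name()) before doing any membership tests, and adds _Backend.HPU_ATTN to the set of backends that are kept rather than demoted to TORCH_SDPA. Below is a minimal, self-contained sketch of that selection flow; the local _Backend enum and the select_mha_backend helper are illustrative stand-ins for this note, not vLLM APIs.

```python
from enum import Enum, auto


class _Backend(Enum):
    """Stand-in for vLLM's _Backend enum; only the members used here."""
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()
    HPU_ATTN = auto()


def select_mha_backend(backend_name: str) -> _Backend:
    """Hypothetical helper mirroring the selection logic in the hunk above."""
    # In the diff, backend_name_to_enum(attn_backend.get_name()) maps the
    # backend class returned by get_attn_backend() to a _Backend member;
    # a plain name lookup plays that role here.
    attn_backend_enum = _Backend[backend_name]

    # Flash-attention backends are remapped to xFormers for this layer.
    if attn_backend_enum in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        attn_backend_enum = _Backend.XFORMERS

    # Anything outside the supported set falls back to torch SDPA.
    return attn_backend_enum if attn_backend_enum in {
        _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
    } else _Backend.TORCH_SDPA


print(select_mha_backend("HPU_ATTN"))    # _Backend.HPU_ATTN is now kept
print(select_mha_backend("FLASH_ATTN"))  # remapped to _Backend.XFORMERS
```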
@@ -228,6 +231,15 @@ def forward(
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.HPU_ATTN:
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+            out = F.scaled_dot_product_attention(query,
+                                                 key,
+                                                 value,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2).contiguous()
 
         return out.view(bsz, q_len, -1)
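The new HPU_ATTN branch follows the same shape handling as the TORCH_SDPA branch above it, with an extra .contiguous() after the final transpose so the view() at the end of forward() succeeds. Here is a small standalone sketch of that tensor flow; the shapes are illustrative, nothing is HPU-specific, and the scale keyword of F.scaled_dot_product_attention needs PyTorch 2.1 or newer.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only.
bsz, q_len, kv_len, num_heads, head_dim = 2, 8, 8, 4, 64
scale = head_dim ** -0.5

# The layer holds query/key/value as (batch, seq_len, num_heads, head_dim).
query = torch.randn(bsz, q_len, num_heads, head_dim)
key = torch.randn(bsz, kv_len, num_heads, head_dim)
value = torch.randn(bsz, kv_len, num_heads, head_dim)

# F.scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim),
# hence the transpose on the way in and the transpose back afterwards.
q, k, v = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(q, k, v, scale=scale)
out = out.transpose(1, 2).contiguous()

# .contiguous() matters: the transposed result is not contiguous, and the
# final reshape to (bsz, q_len, hidden_size) uses view().
print(out.view(bsz, q_len, -1).shape)  # torch.Size([2, 8, 256])
```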


