Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: selecting correct backend for MultiHeadAttention #645

Merged
merged 18 commits into from
Feb 3, 2025
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ def __init__(
backend = _Backend.XFORMERS

self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
} else _Backend.TORCH_SDPA

def forward(
Expand Down Expand Up @@ -279,6 +278,36 @@ def forward(
value,
scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.HPU_ATTN:
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))

from vllm_hpu_extension.flags import enabled_flags

if "fsdpa" in enabled_flags():
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA

fsdpa_op = ModuleFusedSDPA(FusedSDPA)

out = fsdpa_op(query,
key,
value,
None,
dropout_p=0.0,
is_causal=False,
scale=self.scale,
softmax_mode="fast",
recompute_mode=True,
valid_sequence_lengths=None)
else:
out = F.scaled_dot_product_attention(query,
key,
value,
scale=self.scale)

out = out.transpose(1, 2)

return out.reshape(bsz, q_len, -1)


Expand Down
Loading