Commit

Capabilities overhaul (#692)
madamczykhabana authored Jan 20, 2025
1 parent cc069cb commit fedf706
Showing 4 changed files with 11 additions and 17 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt

@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@ae726d4
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@01090a8
17 changes: 4 additions & 13 deletions vllm/attention/backends/hpu_attn.py

@@ -2,12 +2,13 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################

-import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

 import torch
+import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
+from vllm_hpu_extension.flags import enabled_flags
 from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
                                       VLLMKVCache)

@@ -17,18 +18,9 @@
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
 from vllm.logger import init_logger
-from vllm.utils import is_fake_hpu

 logger = init_logger(__name__)

-HPUFusedSDPA = None
-try:
-    from habana_frameworks.torch.hpex.kernels import FusedSDPA
-    HPUFusedSDPA = FusedSDPA
-except ImportError:
-    logger.warning("Could not import HPU FusedSDPA kernel. "
-                   "vLLM will use native implementation.")
-

 class HPUAttentionBackend(AttentionBackend):

@@ -139,6 +131,7 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
+        HPUFusedSDPA = kernels.fsdpa()
         self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
             else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

@@ -151,9 +144,7 @@
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                               '1').lower() in ['1', 'true'] \
-            and not is_fake_hpu()
+        self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
         if self.prefill_use_fusedsdpa:
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
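Taken together, these hunks swap two ad-hoc mechanisms for the extension's capability registry: the module-level try/except import of FusedSDPA becomes a kernels.fsdpa() lookup at construction time, and the VLLM_PROMPT_USE_FUSEDSDPA environment variable (whose default here, '1', disagreed with the '0' used in hpu_model_runner.py) becomes a single "fsdpa" in enabled_flags() check. A minimal before/after sketch using only the calls visible in this diff; the helper names are illustrative:

import vllm_hpu_extension.kernels as kernels
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import ModuleFusedSDPA

def build_fused_sdpa():
    # kernels.fsdpa() returns the fused kernel, or None when unavailable,
    # replacing the try/except ImportError block deleted above.
    fsdpa_kernel = kernels.fsdpa()
    return None if fsdpa_kernel is None else ModuleFusedSDPA(fsdpa_kernel)

def prefill_uses_fused_sdpa():
    # Replaces the os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', ...) parsing:
    # one registry now decides whether the capability is enabled.
    return "fsdpa" in enabled_flags()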
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/layernorm.py

@@ -104,7 +104,8 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        from vllm_hpu_extension.ops import HPUFusedRMSNorm
+        from vllm_hpu_extension.kernels import rms_norm
+        HPUFusedRMSNorm = rms_norm()
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
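The layernorm change follows the same pattern: the unconditional import of HPUFusedRMSNorm is replaced by a rms_norm() lookup that may return None, with forward_native as the fallback. A sketch of the control flow, assuming the kernel exposes an .apply(x, weight, eps) entry point (an assumption; only the lookup itself appears in the diff):

import torch
from vllm_hpu_extension.kernels import rms_norm

def forward_hpu_sketch(x: torch.Tensor, weight: torch.Tensor,
                       eps: float) -> torch.Tensor:
    HPUFusedRMSNorm = rms_norm()  # None when the fused kernel is missing
    if HPUFusedRMSNorm is None:
        # Fallback mirroring forward_native: plain RMSNorm in PyTorch.
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight
    # The .apply signature below is assumed, not shown in the diff.
    return HPUFusedRMSNorm.apply(x, weight, eps)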
6 changes: 4 additions & 2 deletions vllm/worker/hpu_model_runner.py

@@ -19,7 +19,9 @@
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
+import vllm_hpu_extension.environment as environment
 from vllm_hpu_extension.bucketing import HPUBucketingContext
+from vllm_hpu_extension.flags import enabled_flags
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,

@@ -214,8 +216,7 @@ class HpuModelAdapter:

     def __init__(self, model, vllm_config, layer_names):
         self.model = model
-        self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                               '0').lower() in ['1', 'true']
+        self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
         self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                                            'false').lower() in ['1', 'true']
         self.vllm_config = vllm_config

@@ -597,6 +598,7 @@ def __init__(
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ):
         ModelRunnerBase.__init__(self, vllm_config=vllm_config)
+        environment.set_model_config(self.model_config)
         self.is_driver_worker = is_driver_worker
         self.return_hidden_states = return_hidden_states
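In hpu_model_runner.py the same flag check replaces the env-var parse, and the runner now registers its model config with the extension's environment module during construction, presumably so that capability flags can be resolved against the active model. A sketch of the resulting initialization order; the runner class here is illustrative, not vLLM's actual HPUModelRunner:

import vllm_hpu_extension.environment as environment
from vllm_hpu_extension.flags import enabled_flags

class RunnerSketch:
    def __init__(self, vllm_config):
        self.model_config = vllm_config.model_config
        # Register the model config before consulting the flags,
        # matching the order introduced in this commit.
        environment.set_model_config(self.model_config)
        self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()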
