diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index ed6e57548edd5..ab4b823784bdc 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@ae726d4
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@01090a8
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index cfb6ecd57181c..1893f98d8af77 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -2,12 +2,13 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################
 
-import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
+import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
+from vllm_hpu_extension.flags import enabled_flags
 from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
                                       VLLMKVCache)
 
@@ -17,18 +18,9 @@
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
 from vllm.logger import init_logger
-from vllm.utils import is_fake_hpu
 
 logger = init_logger(__name__)
 
-HPUFusedSDPA = None
-try:
-    from habana_frameworks.torch.hpex.kernels import FusedSDPA
-    HPUFusedSDPA = FusedSDPA
-except ImportError:
-    logger.warning("Could not import HPU FusedSDPA kernel. "
-                   "vLLM will use native implementation.")
-
 
 class HPUAttentionBackend(AttentionBackend):
 
@@ -139,6 +131,7 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
+        HPUFusedSDPA = kernels.fsdpa()
         self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
             else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
@@ -151,9 +144,7 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                               '1').lower() in ['1', 'true'] \
-            and not is_fake_hpu()
+        self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
         if self.prefill_use_fusedsdpa:
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
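The hpu_attn.py change above replaces the per-module try/except import and the `VLLM_PROMPT_USE_FUSEDSDPA` environment variable with the extension's centralized capability detection: `kernels.fsdpa()` returns the FusedSDPA kernel (or `None` when it is unavailable), and `enabled_flags()` reports which features the extension has turned on. A minimal sketch of the new pattern, using only the two extension APIs that appear in this diff; the fallback wiring is illustrative, not vLLM's exact code:

```python
# Sketch of the centralized kernel/flag discovery (assumes
# vllm-hpu-extension is installed on an HPU-capable host).
import vllm_hpu_extension.kernels as kernels
from vllm_hpu_extension.flags import enabled_flags

# Kernel discovery now lives in the extension: fsdpa() hands back
# the FusedSDPA kernel, or None when it could not be imported.
fsdpa_kernel = kernels.fsdpa()

# Feature gating is a flag query instead of an env-var string compare
# (previously: os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '1') ...).
prefill_use_fusedsdpa = "fsdpa" in enabled_flags()

if fsdpa_kernel is None or not prefill_use_fusedsdpa:
    # Illustrative fallback path; hpu_attn.py keeps
    # fused_scaled_dot_product_attention as None in this case.
    print("FusedSDPA unavailable/disabled; using native attention path")
```

This moves the decision of which kernels are usable out of each call site and into the pinned extension revision (ae726d4 -> 01090a8), so vLLM and the extension stay in sync.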
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 43ea4eb5a4d1a..58e82884df7a1 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -104,7 +104,8 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        from vllm_hpu_extension.ops import HPUFusedRMSNorm
+        from vllm_hpu_extension.kernels import rms_norm
+        HPUFusedRMSNorm = rms_norm()
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 88699369116a3..fd0c40e803f54 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -19,7 +19,9 @@
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
+import vllm_hpu_extension.environment as environment
 from vllm_hpu_extension.bucketing import HPUBucketingContext
+from vllm_hpu_extension.flags import enabled_flags
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
@@ -214,8 +216,7 @@ class HpuModelAdapter:
 
     def __init__(self, model, vllm_config, layer_names):
         self.model = model
-        self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                               '0').lower() in ['1', 'true']
+        self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
         self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                                            'false').lower() in ['1', 'true']
         self.vllm_config = vllm_config
@@ -597,6 +598,7 @@ def __init__(
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ):
         ModelRunnerBase.__init__(self, vllm_config=vllm_config)
+        environment.set_model_config(self.model_config)
         self.is_driver_worker = is_driver_worker
         self.return_hidden_states = return_hidden_states
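The layernorm.py change applies the same pattern to RMSNorm: `forward_hpu` now asks the extension for the fused kernel at call time and falls back to `forward_native` when it is absent, and `environment.set_model_config(...)` in the model runner presumably lets the extension take the model configuration into account when deriving its enabled flags. A minimal sketch of the fallback shape, assuming the returned kernel exposes an `.apply(x, weight, eps)` entry point (an assumption here; the diff only shows the `None` check), with a generic RMSNorm formula standing in for vLLM's `forward_native`:

```python
# Sketch of the lazy kernel lookup with a native fallback.
import torch
from vllm_hpu_extension.kernels import rms_norm


def rmsnorm_with_fallback(x: torch.Tensor, weight: torch.Tensor,
                          eps: float = 1e-6) -> torch.Tensor:
    fused = rms_norm()  # None when the fused HPU kernel is unavailable
    if fused is not None:
        # Assumed call signature for the fused kernel.
        return fused.apply(x, weight, eps)
    # Native fallback: scale by the root-mean-square of the last dim.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight
```

Resolving the kernel inside the forward call (rather than at import time) means a missing kernel degrades gracefully per invocation instead of failing module import.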