diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index e50bf56674e03..b98722cc4a216 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -150,11 +150,11 @@ def __init__(
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
 
-        suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Supported head sizes are: {supported_head_sizes}.")
 
         self.attn_type = attn_type
         if (self.attn_type != AttentionType.DECODER
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 5eda6f40a05e6..8fb9b9aace15d 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -834,7 +834,27 @@ def _attention_with_mask(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
-        if len(kv_cache.shape) > 1:
+        if is_hpu and kv_cache is not None and isinstance(kv_cache, tuple):
+            assert self.attn.backend == _Backend.HPU_ATTN
+            # During cross-attention decode, key & value will be None,
+            # we don't need to cache them.
+            if (k is not None) and (v is not None):
+                from vllm_hpu_extension.utils import VLLMKVCache
+
+                from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention
+                key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                    kv_cache, self.num_local_key_value_heads, self.head_dim)
+                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+                block_indices = attn_metadata.cross_block_indices
+                block_offsets = attn_metadata.cross_block_offsets
+                k_cache = VLLMKVCache()
+                v_cache = VLLMKVCache()
+                key_cache = k_cache(cached_k, key_cache, block_indices,
+                                    block_offsets)
+                value_cache = v_cache(cached_v, value_cache, block_indices,
+                                      block_offsets)
+        elif len(kv_cache.shape) > 1:
             i = torch.ones(1, dtype=torch.float32)
             if self.attn.backend in (_Backend.FLASH_ATTN,
                                      _Backend.FLASH_ATTN_VLLM_V1):
@@ -889,14 +909,21 @@ def _attention_with_mask(
                                              kv_len,
                                              self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
-        output = F.scaled_dot_product_attention(q,
-                                                k,
-                                                v,
-                                                attn_mask=attention_mask,
-                                                is_causal=False)
-        output = output.permute(2, 0, 1, 3).reshape(
-            q_len, self.num_local_heads * self.head_dim)
-        return output
+        if current_platform.is_hpu():
+            from habana_frameworks.torch.hpex.kernels import FusedSDPA
+            output = FusedSDPA.apply(q, k, v, attention_mask)
+            output = output.permute(2, 0, 1, 3).reshape(
+                q_len, self.num_local_heads * self.head_dim)
+            return output
+        else:
+            output = F.scaled_dot_product_attention(q,
+                                                    k,
+                                                    v,
+                                                    attn_mask=attention_mask,
+                                                    is_causal=False)
+            output = output.permute(2, 0, 1, 3).reshape(
+                q_len, self.num_local_heads * self.head_dim)
+            return output
 
 
 class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
@@ -961,9 +988,13 @@ def forward(
         # TODO: Change input_tokens tensor at the beginning of model execution
         # to 2D tensor to align with public vllm input_tokens shape. But this
         # will face the graph building failure issue, still need to investigate.
-        if len(hidden_states.shape) == 3:
-            full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
-                hidden_states.size(0), -1, 1)
+        assert len(residual.shape) == 3
+        if len(hidden_states.shape) == 2:
+            hidden_states = hidden_states.view(residual.size(0),
+                                               residual.size(1),
+                                               residual.size(2))
+        full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
+            hidden_states.size(0), -1, 1)
         hidden_states = full_text_row_masked_out_mask * hidden_states
         hidden_states = residual + self.cross_attn_attn_gate.tanh(
         ) * hidden_states
@@ -1317,7 +1348,12 @@ def get_cross_attention_mask(
     num_tokens_per_tile: int,
     dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    token_ids = input_ids.tolist()
+    token_ids = []
+    if is_hpu:
+        # input_ids is not flatten yet for hpu
+        token_ids = input_ids.flatten().tolist()
+    else:
+        token_ids = input_ids.tolist()
     start = 0
     batch_token_ids = []
     for seq_len in attn_metadata.seq_lens:
diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py
index 13d6758a34976..5813fae391feb 100644
--- a/vllm/worker/hpu_enc_dec_model_runner.py
+++ b/vllm/worker/hpu_enc_dec_model_runner.py
@@ -8,6 +8,7 @@
 
 import habana_frameworks.torch as htorch
 import torch
+from PIL import Image
 from vllm_hpu_extension.ops import batch2block, block2batch
 
 from vllm.attention import AttentionMetadata
@@ -20,7 +21,7 @@
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceData, SequenceGroupMetadata,
                            SequenceOutput)
-from vllm.utils import is_fake_hpu
+from vllm.utils import is_fake_hpu, is_list_of
 from vllm.worker.hpu_model_runner import (HpuModelAdapter, HPUModelRunnerBase,
                                           ModelInputForHPUWithSamplingMetadata,
                                           setup_profiler, subtuple)
@@ -357,9 +358,13 @@ def _prepare_encoder_model_input_tensors(
 
         return attn_metadata
 
+    @torch.inference_mode()
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        kv_caches = [
+            torch.tensor([], dtype=torch.bfloat16, device=self.device)
+            for _ in range(num_layers)
+        ]
         max_batch_size = self.max_num_prefill_seqs
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_seq_len = min(self.max_num_batched_tokens // max_batch_size,
@@ -447,10 +452,10 @@ def create_dummy_seq_group_metadata(self,
         cross_block_table: Optional[List[int]] = None
         encoder_dummy_data \
             = self.input_registry.dummy_data_for_profiling(
-            self.model_config,
-            seq_len,
-            self.mm_registry,
-            is_encoder_data=True)
+                self.model_config,
+                seq_len,
+                self.mm_registry,
+                is_encoder_data=True)
         mm_counts = self.mm_registry.get_mm_limits_per_prompt(
             self.model_config)
         num_images = mm_counts["image"]
@@ -471,18 +476,32 @@ def create_dummy_seq_group_metadata(self,
                                    max_mm_tokens) // self.block_size
             cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
         prompt_token_ids = [0] * input_len
+        if is_prompt:
+            image_data = encoder_dummy_data.multi_modal_data["image"]
+            if isinstance(image_data, Image.Image):
+                image_data = [image_data]
+            assert is_list_of(image_data, Image.Image)
+            text_prompt_len = input_len - 2 - len(image_data)
+            # for prompt like '<|image|><|image|><|begin_of_text|>...', token
+            # ids will be '128000 128256 128256 128000 ...'
+            prompt_token_ids = [128000] + [128256] * len(image_data) + [
+                128000
+            ] + [0] * text_prompt_len
         output_token_ids = [1] * output_len
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         seq_data = SequenceData(prompt_token_ids_array)
         seq_data.output_token_ids = output_token_ids
+
         return SequenceGroupMetadata(
             request_id=str(group_id),
-            is_prompt=(output_len == 0),
+            is_prompt=is_prompt,
             seq_data={group_id: seq_data},
             sampling_params=sampling_params,
             block_tables=block_tables,
             encoder_seq_data=encoder_dummy_data.seq_data,
             multi_modal_data=encoder_dummy_data.multi_modal_data,
+            multi_modal_placeholders=encoder_dummy_data.
+            multi_modal_placeholders,
             cross_block_table=cross_block_table)
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
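
Note (illustrative, not part of the patch): the mllama.py change routes cross-attention through Habana's FusedSDPA kernel when running on HPU and keeps the existing torch path everywhere else. A minimal sketch of that dispatch follows; `use_hpu` is a stand-in for `current_platform.is_hpu()`, and the FusedSDPA import only resolves inside a Habana (Gaudi) software stack.

import torch
import torch.nn.functional as F


def sdpa_with_mask(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   attention_mask: torch.Tensor,
                   use_hpu: bool) -> torch.Tensor:
    # Mirrors the branch added in _attention_with_mask: FusedSDPA on HPU,
    # torch's scaled_dot_product_attention otherwise.
    if use_hpu:
        from habana_frameworks.torch.hpex.kernels import FusedSDPA
        return FusedSDPA.apply(q, k, v, attention_mask)
    return F.scaled_dot_product_attention(q,
                                          k,
                                          v,
                                          attn_mask=attention_mask,
                                          is_causal=False)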
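
Note (illustrative, not part of the patch): the profiling change in create_dummy_seq_group_metadata replaces the all-zero dummy prompt with one that carries image placeholder tokens. A runnable sketch of that layout, assuming the Llama 3.2 ids hard-coded in the patch (128000 for <|begin_of_text|>, 128256 for <|image|>):

from typing import List


def dummy_mllama_prompt_ids(input_len: int, num_images: int) -> List[int]:
    # BOS, one <|image|> token per image, <|begin_of_text|>, then zero padding
    # so the sequence is exactly input_len tokens long.
    text_prompt_len = input_len - 2 - num_images
    return [128000] + [128256] * num_images + [128000] + [0] * text_prompt_len


assert len(dummy_mllama_prompt_ids(32, 2)) == 32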