multi-image support for llama3.2 #705

Open · wants to merge 9 commits into habana_main
6 changes: 3 additions & 3 deletions vllm/attention/backends/hpu_attn.py
@@ -150,11 +150,11 @@ def __init__(
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

- suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
- if head_size not in suppored_head_sizes:
+ supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
+ if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
- f"Supported head sizes are: {suppored_head_sizes}.")
+ f"Supported head sizes are: {supported_head_sizes}.")

self.attn_type = attn_type
if (self.attn_type != AttentionType.DECODER
62 changes: 49 additions & 13 deletions vllm/model_executor/models/mllama.py
@@ -834,7 +834,27 @@ def _attention_with_mask(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
- if len(kv_cache.shape) > 1:
+ if is_hpu and kv_cache is not None and isinstance(kv_cache, tuple):
+ assert self.attn.backend == _Backend.HPU_ATTN
+ # During cross-attention decode, key & value will be None,
+ # we don't need to cache them.
+ if (k is not None) and (v is not None):
+ from vllm_hpu_extension.utils import VLLMKVCache
+
+ from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention
+ key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+ kv_cache, self.num_local_key_value_heads, self.head_dim)
+ cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+ cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+ block_indices = attn_metadata.cross_block_indices
+ block_offsets = attn_metadata.cross_block_offsets
+ k_cache = VLLMKVCache()
+ v_cache = VLLMKVCache()
+ key_cache = k_cache(cached_k, key_cache, block_indices,
+ block_offsets)
+ value_cache = v_cache(cached_v, value_cache, block_indices,
+ block_offsets)
+ elif len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1):
@@ -889,14 +909,21 @@ def _attention_with_mask(
kv_len,
self.head_dim).contiguous()
attention_mask = attention_mask.view(1, 1, q_len, kv_len)
- output = F.scaled_dot_product_attention(q,
- k,
- v,
- attn_mask=attention_mask,
- is_causal=False)
- output = output.permute(2, 0, 1, 3).reshape(
- q_len, self.num_local_heads * self.head_dim)
- return output
+ if current_platform.is_hpu():
+ from habana_frameworks.torch.hpex.kernels import FusedSDPA
+ output = FusedSDPA.apply(q, k, v, attention_mask)
+ output = output.permute(2, 0, 1, 3).reshape(
+ q_len, self.num_local_heads * self.head_dim)
+ return output
+ else:
+ output = F.scaled_dot_product_attention(q,
+ k,
+ v,
+ attn_mask=attention_mask,
+ is_causal=False)
+ output = output.permute(2, 0, 1, 3).reshape(
+ q_len, self.num_local_heads * self.head_dim)
+ return output
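For readers checking the tensor bookkeeping in this hunk, the snippet below (not part of the PR) is a minimal sketch of the shapes implied by the code above: 4D query/key/value tensors whose first two dimensions multiply out to the number of local heads, an additive mask viewed as (1, 1, q_len, kv_len), and the final permute/reshape to (q_len, num_local_heads * head_dim). It uses torch's stock scaled_dot_product_attention so it runs anywhere; the HPU branch above substitutes habana's FusedSDPA with the same (q, k, v, mask) operands. All sizes are made up for illustration.

import torch
import torch.nn.functional as F

# Illustrative sizes only (not the real mllama configuration).
num_kv_heads, groups, head_dim, q_len, kv_len = 2, 4, 64, 16, 24
num_heads = num_kv_heads * groups

q = torch.randn(num_kv_heads, groups, q_len, head_dim)
k = torch.randn(num_kv_heads, groups, kv_len, head_dim)
v = torch.randn(num_kv_heads, groups, kv_len, head_dim)
attention_mask = torch.zeros(1, 1, q_len, kv_len)  # additive mask, 0 = attend

# Portable path, as in the `else` branch above.
output = F.scaled_dot_product_attention(q, k, v,
                                        attn_mask=attention_mask,
                                        is_causal=False)
# On HPU the patch calls instead:
#   from habana_frameworks.torch.hpex.kernels import FusedSDPA
#   output = FusedSDPA.apply(q, k, v, attention_mask)

# (kv_heads, groups, q_len, head_dim) -> (q_len, num_heads * head_dim)
output = output.permute(2, 0, 1, 3).reshape(q_len, num_heads * head_dim)
print(output.shape)  # torch.Size([16, 512])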


class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
@@ -961,9 +988,13 @@ def forward(
# TODO: Change input_tokens tensor at the beginning of model execution
# to 2D tensor to align with public vllm input_tokens shape. But this
# will face the graph building failure issue, still need to investigate.
- if len(hidden_states.shape) == 3:
- full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
- hidden_states.size(0), -1, 1)
+ assert len(residual.shape) == 3
+ if len(hidden_states.shape) == 2:
+ hidden_states = hidden_states.view(residual.size(0),
+ residual.size(1),
+ residual.size(2))
+ full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
+ hidden_states.size(0), -1, 1)
hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_attn_gate.tanh(
) * hidden_states
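The shape juggling in this hunk is easier to see in isolation. Below is a small sketch (not part of the PR, with made-up sizes and a scalar stand-in for the learned cross_attn_attn_gate) showing why the 2D hidden_states produced by upstream-style code must be viewed back into the 3D (batch, seq, hidden) layout of residual before the (batch, seq, 1) row mask can broadcast.

import torch

# Made-up sizes; on HPU input_tokens stays 3D (batch, seq, hidden),
# while the attention output arrives flattened to (batch * seq, hidden).
batch, seq, hidden = 2, 4, 8
residual = torch.randn(batch, seq, hidden)
hidden_states = torch.randn(batch * seq, hidden)
full_text_row_masked_out_mask = torch.randint(0, 2, (batch * seq, 1)).float()
gate = torch.tensor(0.3)  # stand-in for the learned cross_attn_attn_gate

# Mirror the patched forward(): lift 2D activations to residual's 3D shape,
# then view the per-row mask as (batch, seq, 1) so it broadcasts.
if hidden_states.dim() == 2:
    hidden_states = hidden_states.view(residual.size(0), residual.size(1),
                                       residual.size(2))
full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
    hidden_states.size(0), -1, 1)
hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + gate.tanh() * hidden_states
print(hidden_states.shape)  # torch.Size([2, 4, 8])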
@@ -1317,7 +1348,12 @@ def get_cross_attention_mask(
num_tokens_per_tile: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
- token_ids = input_ids.tolist()
+ token_ids = []
+ if is_hpu:
+ # input_ids is not flattened yet for HPU
+ token_ids = input_ids.flatten().tolist()
+ else:
+ token_ids = input_ids.tolist()
start = 0
batch_token_ids = []
for seq_len in attn_metadata.seq_lens:
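The VLLMKVCache call in the first hunk above scatters the cross-attention keys and values into a paged cache at the slots given by cross_block_indices/cross_block_offsets. As a rough mental model only — the real vllm_hpu_extension kernel owns the exact cache layout, dtype handling, and in-place update — the effect resembles the plain-PyTorch scatter below, which assumes a (num_blocks, block_size, kv_heads, head_dim) cache; that shape and the helper name are assumptions of this sketch, not part of the PR.

import torch

def write_kv_rows(cache: torch.Tensor, rows: torch.Tensor,
                  block_indices: torch.Tensor,
                  block_offsets: torch.Tensor) -> torch.Tensor:
    """Scatter `rows` (num_tokens, kv_heads, head_dim) into a paged cache
    assumed to be shaped (num_blocks, block_size, kv_heads, head_dim)."""
    cache = cache.clone()  # the real kernel updates the cache in place
    cache[block_indices, block_offsets] = rows
    return cache

# Toy dimensions, for illustration only.
num_blocks, block_size, kv_heads, head_dim = 4, 8, 2, 16
key_cache = torch.zeros(num_blocks, block_size, kv_heads, head_dim)

# Three cross-attention tokens land in block 1 at offsets 0..2.
cached_k = torch.randn(3, kv_heads, head_dim)
block_indices = torch.tensor([1, 1, 1])
block_offsets = torch.tensor([0, 1, 2])

key_cache = write_kv_rows(key_cache, cached_k, block_indices, block_offsets)
assert torch.equal(key_cache[1, :3], cached_k)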
33 changes: 26 additions & 7 deletions vllm/worker/hpu_enc_dec_model_runner.py
@@ -8,6 +8,7 @@

import habana_frameworks.torch as htorch
import torch
+ from PIL import Image
from vllm_hpu_extension.ops import batch2block, block2batch

from vllm.attention import AttentionMetadata
@@ -20,7 +21,7 @@
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceData, SequenceGroupMetadata,
SequenceOutput)
- from vllm.utils import is_fake_hpu
+ from vllm.utils import is_fake_hpu, is_list_of
from vllm.worker.hpu_model_runner import (HpuModelAdapter, HPUModelRunnerBase,
ModelInputForHPUWithSamplingMetadata,
setup_profiler, subtuple)
@@ -357,9 +358,13 @@ def _prepare_encoder_model_input_tensors(

return attn_metadata

+ @torch.inference_mode()
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
- kv_caches = [None] * num_layers
+ kv_caches = [
+ torch.tensor([], dtype=torch.bfloat16, device=self.device)
+ for _ in range(num_layers)
+ ]
max_batch_size = self.max_num_prefill_seqs
_, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
max_seq_len = min(self.max_num_batched_tokens // max_batch_size,
@@ -447,10 +452,10 @@ def create_dummy_seq_group_metadata(self,
cross_block_table: Optional[List[int]] = None
encoder_dummy_data \
= self.input_registry.dummy_data_for_profiling(
- self.model_config,
- seq_len,
- self.mm_registry,
- is_encoder_data=True)
+ self.model_config,
+ seq_len,
+ self.mm_registry,
+ is_encoder_data=True)
mm_counts = self.mm_registry.get_mm_limits_per_prompt(
self.model_config)
num_images = mm_counts["image"]
@@ -471,18 +476,32 @@
max_mm_tokens) // self.block_size
cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
prompt_token_ids = [0] * input_len
+ if is_prompt:
+ image_data = encoder_dummy_data.multi_modal_data["image"]
+ if isinstance(image_data, Image.Image):
+ image_data = [image_data]
+ assert is_list_of(image_data, Image.Image)
+ text_prompt_len = input_len - 2 - len(image_data)
+ # for a prompt like '<|image|><|image|><|begin_of_text|>...', the token
+ # ids will be '128000 128256 128256 128000 ...'
+ prompt_token_ids = [128000] + [128256] * len(image_data) + [
+ 128000
+ ] + [0] * text_prompt_len
output_token_ids = [1] * output_len
prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821
seq_data = SequenceData(prompt_token_ids_array)
seq_data.output_token_ids = output_token_ids

return SequenceGroupMetadata(
request_id=str(group_id),
- is_prompt=(output_len == 0),
+ is_prompt=is_prompt,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=block_tables,
encoder_seq_data=encoder_dummy_data.seq_data,
+ multi_modal_data=encoder_dummy_data.multi_modal_data,
+ multi_modal_placeholders=encoder_dummy_data.
+ multi_modal_placeholders,
cross_block_table=cross_block_table)

def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
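To make the dummy-prompt construction above concrete, the sketch below (a hypothetical helper, not part of the PR) rebuilds the token-id layout described in the diff's own comment, where 128000 and 128256 are the begin-of-text and image-placeholder ids used when profiling Llama 3.2 vision prompts.

from typing import List

BEGIN_OF_TEXT = 128000    # ids taken from the comment in the diff above
IMAGE_PLACEHOLDER = 128256

def dummy_prompt_token_ids(input_len: int, num_images: int) -> List[int]:
    """Mirror of the dummy prompt built in create_dummy_seq_group_metadata:
    128000, one 128256 per image, 128000 again, then zero padding up to
    input_len tokens."""
    text_prompt_len = input_len - 2 - num_images
    return ([BEGIN_OF_TEXT] + [IMAGE_PLACEHOLDER] * num_images +
            [BEGIN_OF_TEXT] + [0] * text_prompt_len)

# A two-image dummy prompt padded to 8 tokens:
print(dummy_prompt_token_ids(8, 2))
# [128000, 128256, 128256, 128000, 0, 0, 0, 0]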