From 5310f9e4e13c1ed68d93322d890b5fda69eaa311 Mon Sep 17 00:00:00 2001 From: yan ma Date: Mon, 20 Jan 2025 23:11:45 +0800 Subject: [PATCH 1/7] multi-image support for llama3.2 Signed-off-by: yan ma --- vllm/model_executor/models/mllama.py | 48 ++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 5eda6f40a05e6..6b94c6f7aaa09 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -834,7 +834,26 @@ def _attention_with_mask( attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. - if len(kv_cache.shape) > 1: + if is_hpu and kv_cache is not None: + assert self.attn.backend == _Backend.HPU_ATTN + # During cross-attention decode, key & value will be None, + # we don't need to cache them. + if (k is not None) and (v is not None): + from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention + from vllm_hpu_extension.utils import VLLMKVCache + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_local_key_value_heads, self.head_dim) + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + block_indices = attn_metadata.cross_block_indices + block_offsets = attn_metadata.cross_block_offsets + k_cache = VLLMKVCache() + v_cache = VLLMKVCache() + key_cache = k_cache(cached_k, key_cache, block_indices, + block_offsets) + value_cache = v_cache(cached_v, value_cache, block_indices, + block_offsets) + elif len(kv_cache.shape) > 1: i = torch.ones(1, dtype=torch.float32) if self.attn.backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1): @@ -889,14 +908,18 @@ def _attention_with_mask( kv_len, self.head_dim).contiguous() attention_mask = attention_mask.view(1, 1, q_len, kv_len) - output = F.scaled_dot_product_attention(q, - k, - v, - attn_mask=attention_mask, - is_causal=False) - output = output.permute(2, 0, 1, 3).reshape( - q_len, self.num_local_heads * self.head_dim) - return output + if current_platform.is_hpu(): + from habana_frameworks.torch.hpex.kernels import FusedSDPA + output = FusedSDPA.apply(q, k, v, attention_mask, 0.0) + output = output.permute(2, 0, 1, 3).reshape( + q_len, self.num_local_heads * self.head_dim) + return output + else: + output = F.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, dropout_p=0.0) + output = output.permute(2, 0, 1, 3).reshape( + q_len, self.num_local_heads * self.head_dim) + return output class MllamaCrossAttentionDecoderLayer(torch.nn.Module): @@ -1317,7 +1340,12 @@ def get_cross_attention_mask( num_tokens_per_tile: int, dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor]: - token_ids = input_ids.tolist() + token_ids = [] + if is_hpu: + # input_ids is not flatten yet for hpu + token_ids = input_ids.flatten().tolist() + else: + token_ids = input_ids.tolist() start = 0 batch_token_ids = [] for seq_len in attn_metadata.seq_lens: From cdc00d5c984eac39783bca98674cac0d37069477 Mon Sep 17 00:00:00 2001 From: yan ma Date: Tue, 21 Jan 2025 04:17:45 +0800 Subject: [PATCH 2/7] fix Signed-off-by: yan ma --- vllm/attention/backends/hpu_attn.py | 6 +++--- vllm/model_executor/models/mllama.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index e50bf56674e03..b98722cc4a216 100644 --- a/vllm/attention/backends/hpu_attn.py 
+++ b/vllm/attention/backends/hpu_attn.py @@ -150,11 +150,11 @@ def __init__( assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' - suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") self.attn_type = attn_type if (self.attn_type != AttentionType.DECODER diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 6b94c6f7aaa09..721ba1c8faa55 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -834,13 +834,14 @@ def _attention_with_mask( attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. - if is_hpu and kv_cache is not None: + if is_hpu and kv_cache is not None and isinstance(kv_cache, tuple): assert self.attn.backend == _Backend.HPU_ATTN # During cross-attention decode, key & value will be None, # we don't need to cache them. if (k is not None) and (v is not None): - from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention from vllm_hpu_extension.utils import VLLMKVCache + + from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_local_key_value_heads, self.head_dim) cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) @@ -850,9 +851,9 @@ def _attention_with_mask( k_cache = VLLMKVCache() v_cache = VLLMKVCache() key_cache = k_cache(cached_k, key_cache, block_indices, - block_offsets) + block_offsets) value_cache = v_cache(cached_v, value_cache, block_indices, - block_offsets) + block_offsets) elif len(kv_cache.shape) > 1: i = torch.ones(1, dtype=torch.float32) if self.attn.backend in (_Backend.FLASH_ATTN, @@ -915,8 +916,11 @@ def _attention_with_mask( q_len, self.num_local_heads * self.head_dim) return output else: - output = F.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, dropout_p=0.0) + output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attention_mask, + dropout_p=0.0) output = output.permute(2, 0, 1, 3).reshape( q_len, self.num_local_heads * self.head_dim) return output From 5154ee7aacb8097085e3f2624ae2cb08e939ca9e Mon Sep 17 00:00:00 2001 From: yan ma Date: Wed, 22 Jan 2025 18:54:25 +0800 Subject: [PATCH 3/7] fix --- vllm/model_executor/models/mllama.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 721ba1c8faa55..d05e78bf0532a 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -911,7 +911,7 @@ def _attention_with_mask( attention_mask = attention_mask.view(1, 1, q_len, kv_len) if current_platform.is_hpu(): from habana_frameworks.torch.hpex.kernels import FusedSDPA - output = FusedSDPA.apply(q, k, v, attention_mask, 0.0) + output = FusedSDPA.apply(q, k, v, attention_mask) output = output.permute(2, 0, 1, 3).reshape( q_len, self.num_local_heads * self.head_dim) return output @@ -920,7 +920,7 @@ def _attention_with_mask( k, v, attn_mask=attention_mask, - dropout_p=0.0) + is_causal=False) output = output.permute(2, 0, 1, 3).reshape( q_len, self.num_local_heads * 
self.head_dim) return output @@ -988,8 +988,10 @@ def forward( # TODO: Change input_tokens tensor at the beginning of model execution # to 2D tensor to align with public vllm input_tokens shape. But this # will face the graph building failure issue, still need to investigate. - if len(hidden_states.shape) == 3: - full_text_row_masked_out_mask = full_text_row_masked_out_mask.view( + assert len(residual.shape) == 3 + if len(hidden_states.shape)==2: + hidden_states = hidden_states.view(residual.size(0), residual.size(1), residual.size(2)) + full_text_row_masked_out_mask = full_text_row_masked_out_mask.view( hidden_states.size(0), -1, 1) hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = residual + self.cross_attn_attn_gate.tanh( From bfc49fdf8e757cc1292d893be2eaafa2fffa5147 Mon Sep 17 00:00:00 2001 From: yan ma Date: Fri, 24 Jan 2025 17:34:18 +0800 Subject: [PATCH 4/7] fix profile_run Signed-off-by: yan ma --- vllm/worker/hpu_enc_dec_model_runner.py | 198 +++++++++--------------- 1 file changed, 74 insertions(+), 124 deletions(-) diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py index 13d6758a34976..a0b1848ee611f 100644 --- a/vllm/worker/hpu_enc_dec_model_runner.py +++ b/vllm/worker/hpu_enc_dec_model_runner.py @@ -358,132 +358,82 @@ def _prepare_encoder_model_input_tensors( return attn_metadata def profile_run(self) -> None: - num_layers = self.model_config.get_num_layers(self.parallel_config) - kv_caches = [None] * num_layers - max_batch_size = self.max_num_prefill_seqs - _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() - max_seq_len = min(self.max_num_batched_tokens // max_batch_size, - max_seq_len) - - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, - False) - return + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. 
+ seqs: List[SequenceGroupMetadata] = [] - def warmup_scenario(self, - batch_size, - seq_len, - is_prompt, - kv_caches, - is_pt_profiler_run=False, - is_lora_profile_run=False, - temperature=0) -> None: - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) - scenario_name = ("warmup_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}") - self.profiler.start('internal', scenario_name) - times = 3 if use_graphs or is_pt_profiler_run else 1 - if is_prompt: - seqs = [ - self.create_dummy_seq_group_metadata(i, seq_len, is_prompt) - for i in range(batch_size) - ] - else: - # FIXME: seq_len is actually number of blocks - blocks = [seq_len // batch_size for _ in range(batch_size)] - blocks[0] += seq_len % batch_size - seqs = [ - self.create_dummy_seq_group_metadata(i, - b * self.block_size - 1, - is_prompt) - for i, b in enumerate(blocks) - ] - torch.hpu.synchronize() - profiler = None - if is_pt_profiler_run and self.is_driver_worker: - profiler = setup_profiler() - profiler.start() - for _ in range(times): - inputs = self.prepare_model_input(seqs) - is_single_step = \ - self.vllm_config.scheduler_config.num_scheduler_steps == 1 - if is_prompt or is_single_step: - self.execute_model(inputs, kv_caches, warmup_mode=True) - else: # decode with multi-step - inputs = dataclasses.replace(inputs, - is_first_multi_step=True, - is_last_step=False) - self.execute_model(inputs, - kv_caches, - warmup_mode=True, - num_steps=2, - seqs=seqs) - inputs = dataclasses.replace(inputs, - is_first_multi_step=False, - is_last_step=True) - self.execute_model(inputs, - kv_caches, - warmup_mode=True, - num_steps=2, - seqs=seqs) - torch.hpu.synchronize() - if profiler: - profiler.step() - if profiler: - profiler.stop() - self.profiler.end() - gc.collect() - - def create_dummy_seq_group_metadata(self, - group_id, - seq_len, - is_prompt, - lora_request=None, - temperature=0): - sampling_params = SamplingParams(temperature=temperature) - num_blocks = math.ceil(seq_len / self.block_size) - cross_block_table: Optional[List[int]] = None - encoder_dummy_data \ - = self.input_registry.dummy_data_for_profiling( - self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=True) - mm_counts = self.mm_registry.get_mm_limits_per_prompt( - self.model_config) - num_images = mm_counts["image"] max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) * num_images - seq_len = max(seq_len, 1) - if is_prompt: - input_len = seq_len - output_len = 0 - block_tables = None - cross_block_table = None - else: - input_len = seq_len - 1 - output_len = 1 - block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks} - # limit cross blocks to the number of available blocks - num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks, - max_mm_tokens) // self.block_size - cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks - prompt_token_ids = [0] * input_len - output_token_ids = [1] * output_len - prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 - seq_data = SequenceData(prompt_token_ids_array) - seq_data.output_token_ids = output_token_ids - return SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=(output_len == 0), - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=block_tables, - encoder_seq_data=encoder_dummy_data.seq_data, - multi_modal_data=encoder_dummy_data.multi_modal_data, - cross_block_table=cross_block_table) + self.model_config) + if max_mm_tokens > 0: + 
logger.info("Starting profile run for multi-modal models.") + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + decoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=False) + encoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) + + # Having more tokens is over-conservative but otherwise fine + assert len( + decoder_dummy_data.seq_data.prompt_token_ids + ) >= seq_len, ( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" + ) + + assert decoder_dummy_data.multi_modal_data is None or \ + encoder_dummy_data.multi_modal_data is None, ( + "Multi-modal data can't be provided in both encoder and decoder" + ) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: decoder_dummy_data.seq_data}, + sampling_params=sampling_params, + block_tables=None, + encoder_seq_data=encoder_dummy_data.seq_data, + cross_block_table=None, + multi_modal_data=decoder_dummy_data.multi_modal_data + or encoder_dummy_data.multi_modal_data, + multi_modal_placeholders=decoder_dummy_data. + multi_modal_placeholders + or encoder_dummy_data.multi_modal_placeholders) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(num_layers) + ] + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.cuda.synchronize() + return def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # NOTE(kzawora): To anyone working on this in the future: From 6bd06d7d11fe3fbb938aa6879e6cefc8e53184f2 Mon Sep 17 00:00:00 2001 From: yan ma Date: Sat, 25 Jan 2025 00:27:45 +0800 Subject: [PATCH 5/7] fix --- vllm/worker/hpu_enc_dec_model_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py index a0b1848ee611f..6001178e1ea4a 100644 --- a/vllm/worker/hpu_enc_dec_model_runner.py +++ b/vllm/worker/hpu_enc_dec_model_runner.py @@ -361,7 +361,8 @@ def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs + # Workaround to avoid unexpeced OOM failure during profile run + max_num_seqs = int(self.scheduler_config.max_num_seqs/2) # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. 
@@ -432,7 +433,8 @@ def profile_run(self) -> None: seqs, finished_requests_ids=finished_requests_ids) intermediate_tensors = None self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.cuda.synchronize() + torch.hpu.synchronize() + gc.collect() return def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: From d8f816591faeda21a6188f1ada5490b3d0c08647 Mon Sep 17 00:00:00 2001 From: yan ma Date: Sun, 26 Jan 2025 17:43:01 +0800 Subject: [PATCH 6/7] revert --- vllm/model_executor/models/mllama.py | 8 +- vllm/worker/hpu_enc_dec_model_runner.py | 211 ++++++++++++++++-------- 2 files changed, 143 insertions(+), 76 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index d05e78bf0532a..8fb9b9aace15d 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -989,10 +989,12 @@ def forward( # to 2D tensor to align with public vllm input_tokens shape. But this # will face the graph building failure issue, still need to investigate. assert len(residual.shape) == 3 - if len(hidden_states.shape)==2: - hidden_states = hidden_states.view(residual.size(0), residual.size(1), residual.size(2)) + if len(hidden_states.shape) == 2: + hidden_states = hidden_states.view(residual.size(0), + residual.size(1), + residual.size(2)) full_text_row_masked_out_mask = full_text_row_masked_out_mask.view( - hidden_states.size(0), -1, 1) + hidden_states.size(0), -1, 1) hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = residual + self.cross_attn_attn_gate.tanh( ) * hidden_states diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py index 6001178e1ea4a..66ede172cf694 100644 --- a/vllm/worker/hpu_enc_dec_model_runner.py +++ b/vllm/worker/hpu_enc_dec_model_runner.py @@ -8,6 +8,7 @@ import habana_frameworks.torch as htorch import torch +from PIL import Image from vllm_hpu_extension.ops import batch2block, block2batch from vllm.attention import AttentionMetadata @@ -20,7 +21,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceData, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import is_fake_hpu +from vllm.utils import is_fake_hpu, is_list_of from vllm.worker.hpu_model_runner import (HpuModelAdapter, HPUModelRunnerBase, ModelInputForHPUWithSamplingMetadata, setup_profiler, subtuple) @@ -357,85 +358,149 @@ def _prepare_encoder_model_input_tensors( return attn_metadata + @torch.inference_mode() def profile_run(self) -> None: - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - # Workaround to avoid unexpeced OOM failure during profile run - max_num_seqs = int(self.scheduler_config.max_num_seqs/2) - - # Profile memory usage with max_num_sequences sequences and the total - # number of tokens equal to max_num_batched_tokens. 
- seqs: List[SequenceGroupMetadata] = [] - - max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) - if max_mm_tokens > 0: - logger.info("Starting profile run for multi-modal models.") - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - decoder_dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=False) - encoder_dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=True) - - # Having more tokens is over-conservative but otherwise fine - assert len( - decoder_dummy_data.seq_data.prompt_token_ids - ) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" - ) - - assert decoder_dummy_data.multi_modal_data is None or \ - encoder_dummy_data.multi_modal_data is None, ( - "Multi-modal data can't be provided in both encoder and decoder" - ) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: decoder_dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - encoder_seq_data=encoder_dummy_data.seq_data, - cross_block_table=None, - multi_modal_data=decoder_dummy_data.multi_modal_data - or encoder_dummy_data.multi_modal_data, - multi_modal_placeholders=decoder_dummy_data. - multi_modal_placeholders - or encoder_dummy_data.multi_modal_placeholders) - seqs.append(seq) - - # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). 
kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) + torch.tensor([], dtype=torch.bfloat16, device=self.device) for _ in range(num_layers) ] - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - self.execute_model(model_input, kv_caches, intermediate_tensors) + max_batch_size = self.max_num_prefill_seqs + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + max_seq_len = min(self.max_num_batched_tokens // max_batch_size, + max_seq_len) + + self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, + False) + return + + def warmup_scenario(self, + batch_size, + seq_len, + is_prompt, + kv_caches, + is_pt_profiler_run=False, + is_lora_profile_run=False, + temperature=0) -> None: + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + scenario_name = ("warmup_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}") + self.profiler.start('internal', scenario_name) + times = 3 if use_graphs or is_pt_profiler_run else 1 + if is_prompt: + seqs = [ + self.create_dummy_seq_group_metadata(i, seq_len, is_prompt) + for i in range(batch_size) + ] + else: + # FIXME: seq_len is actually number of blocks + blocks = [seq_len // batch_size for _ in range(batch_size)] + blocks[0] += seq_len % batch_size + seqs = [ + self.create_dummy_seq_group_metadata(i, + b * self.block_size - 1, + is_prompt) + for i, b in enumerate(blocks) + ] torch.hpu.synchronize() + profiler = None + if is_pt_profiler_run and self.is_driver_worker: + profiler = setup_profiler() + profiler.start() + for _ in range(times): + inputs = self.prepare_model_input(seqs) + is_single_step = \ + self.vllm_config.scheduler_config.num_scheduler_steps == 1 + if is_prompt or is_single_step: + self.execute_model(inputs, kv_caches, warmup_mode=True) + else: # decode with multi-step + inputs = dataclasses.replace(inputs, + is_first_multi_step=True, + is_last_step=False) + self.execute_model(inputs, + kv_caches, + warmup_mode=True, + num_steps=2, + seqs=seqs) + inputs = dataclasses.replace(inputs, + is_first_multi_step=False, + is_last_step=True) + self.execute_model(inputs, + kv_caches, + warmup_mode=True, + num_steps=2, + seqs=seqs) + torch.hpu.synchronize() + if profiler: + profiler.step() + if profiler: + profiler.stop() + self.profiler.end() gc.collect() - return + + def create_dummy_seq_group_metadata(self, + group_id, + seq_len, + is_prompt, + lora_request=None, + temperature=0): + sampling_params = SamplingParams(temperature=temperature) + num_blocks = math.ceil(seq_len / self.block_size) + cross_block_table: Optional[List[int]] = None + encoder_dummy_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) + mm_counts = self.mm_registry.get_mm_limits_per_prompt( + self.model_config) + num_images = mm_counts["image"] + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) * num_images + seq_len = max(seq_len, 1) + if is_prompt: + input_len = seq_len + output_len = 0 + block_tables = None + cross_block_table = None + else: + input_len = seq_len - 1 + output_len = 1 + block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks} + # limit cross blocks to the number of available blocks + num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks, + max_mm_tokens) // self.block_size + cross_block_table = 
[_PAD_BLOCK_ID] * num_cross_blocks + prompt_token_ids = [0] * input_len + if is_prompt: + image_data = encoder_dummy_data.multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_data = [image_data] + assert is_list_of(image_data, Image.Image) + text_prompt_len = input_len - 2 - len(image_data) + prompt_token_ids = [128000] + [128256] * len(image_data) + [ + 128000 + ] + [0] * text_prompt_len + output_token_ids = [1] * output_len + prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 + seq_data = SequenceData(prompt_token_ids_array) + seq_data.output_token_ids = output_token_ids + + return SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=is_prompt, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + encoder_seq_data=encoder_dummy_data.seq_data, + multi_modal_data=encoder_dummy_data.multi_modal_data, + multi_modal_placeholders=encoder_dummy_data. + multi_modal_placeholders, + cross_block_table=cross_block_table) def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # NOTE(kzawora): To anyone working on this in the future: From 0c2759e678e691c8bc7367f86c1e20592b088b0c Mon Sep 17 00:00:00 2001 From: yan ma Date: Fri, 31 Jan 2025 18:39:07 +0800 Subject: [PATCH 7/7] fix rebase Signed-off-by: yan ma --- vllm/worker/hpu_enc_dec_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py index 66ede172cf694..5813fae391feb 100644 --- a/vllm/worker/hpu_enc_dec_model_runner.py +++ b/vllm/worker/hpu_enc_dec_model_runner.py @@ -482,6 +482,8 @@ def create_dummy_seq_group_metadata(self, image_data = [image_data] assert is_list_of(image_data, Image.Image) text_prompt_len = input_len - 2 - len(image_data) + # for prompt like '<|image|><|image|><|begin_of_text|>...', token + # ids will be '128000 128256 128256 128000 ...' prompt_token_ids = [128000] + [128256] * len(image_data) + [ 128000 ] + [0] * text_prompt_len
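
Note on the cross-attention KV-cache write added in PATCH 1/7: during cross-attention prefill, `kv_range_for_decode` carries per-sequence `(start, end)` offsets into the flattened key/value tensors, and only those slices are concatenated before being written into the paged cache through `VLLMKVCache`. Below is a minimal, device-agnostic sketch of just that gather step on plain CPU tensors; the shapes and range values are made up for illustration, and the HPU-specific `HPUPagedAttention.split_kv_cache` / `VLLMKVCache` calls are intentionally left out.

```python
import torch

# Hypothetical shapes: 10 flattened cross-attention KV tokens,
# 4 local KV heads, head_dim of 8.
k = torch.randn(10, 4, 8)
v = torch.randn(10, 4, 8)

# (start, end) token ranges per sequence, in the same form that
# kv_range_for_decode provides them to _attention_with_mask.
kv_range_for_decode = [(0, 6), (6, 10)]

# The same gather the patch performs before handing tensors to the cache:
# concatenate only the slices belonging to the sequences being cached.
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])

# Here the ranges happen to cover all 10 tokens, so shapes are unchanged.
assert cached_k.shape == (10, 4, 8)
assert cached_v.shape == (10, 4, 8)
```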
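Note on the dummy-prompt layout used for profiling (PATCH 6/7 and 7/7): `create_dummy_seq_group_metadata` builds the prompt token ids as `128000`, then one `128256` per image, then another `128000`, padded with zeros for the text portion, mirroring a prompt like `<|image|><|image|><|begin_of_text|>...`. A minimal sketch of that layout follows; the helper name `build_dummy_prompt_token_ids` is illustrative and not part of vLLM, while the token ids (128000, 128256) are the ones used in the patch.

```python
from typing import List


def build_dummy_prompt_token_ids(input_len: int, num_images: int) -> List[int]:
    """Illustrative helper: lay out dummy prompt token ids the same way
    create_dummy_seq_group_metadata does in this series."""
    # Two 128000 markers plus one 128256 per image are reserved up front;
    # whatever remains of input_len is filled with padding token id 0.
    text_prompt_len = input_len - 2 - num_images
    assert text_prompt_len >= 0, "input_len too small for the requested images"
    return [128000] + [128256] * num_images + [128000] + [0] * text_prompt_len


if __name__ == "__main__":
    # A 16-token dummy prompt with 2 images:
    # [128000, 128256, 128256, 128000, 0, 0, ...]
    print(build_dummy_prompt_token_ids(16, 2))
```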