diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py index 2951a4db2e478..9fa673cb984ff 100644 --- a/vllm/worker/hpu_enc_dec_model_runner.py +++ b/vllm/worker/hpu_enc_dec_model_runner.py @@ -4,18 +4,21 @@ import math from array import array from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast import habana_frameworks.torch as htorch import torch from vllm_hpu_extension.ops import batch2block, block2batch from vllm.attention import AttentionMetadata +from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.sampling_metadata import SequenceGroupToSample from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SequenceData, - SequenceGroupMetadata) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceData, SequenceGroupMetadata, + SequenceOutput) from vllm.utils import is_fake_hpu from vllm.worker.hpu_model_runner import (HpuModelAdapter, HPUModelRunnerBase, ModelInputForHPUWithSamplingMetadata, @@ -397,7 +400,27 @@ def warmup_scenario(self, profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=True) + is_single_step = \ + self.vllm_config.scheduler_config.num_scheduler_steps == 1 + if is_prompt or is_single_step: + self.execute_model(inputs, kv_caches, warmup_mode=True) + else: # decode with multi-step + inputs = dataclasses.replace(inputs, + is_first_multi_step=True, + is_last_step=False) + self.execute_model(inputs, + kv_caches, + warmup_mode=True, + num_steps=2, + seqs=seqs) + inputs = dataclasses.replace(inputs, + is_first_multi_step=False, + is_last_step=True) + self.execute_model(inputs, + kv_caches, + warmup_mode=True, + num_steps=2, + seqs=seqs) torch.hpu.synchronize() if profiler: profiler.step() @@ -412,7 +435,7 @@ def create_dummy_seq_group_metadata(self, is_prompt, lora_request=None, temperature=0): - sampling_params = SamplingParams(temperature=0) + sampling_params = SamplingParams(temperature=temperature) num_blocks = math.ceil(seq_len / self.block_size) cross_block_table: Optional[List[int]] = None encoder_dummy_data \ @@ -516,6 +539,19 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", phase, batch_size, seq_len) + def add_dummy_seq(self, seq_group_metadata_list, is_prompt): + real_batch_size = len(seq_group_metadata_list) + batch_size_padded = self.bucketing_ctx.get_padded_batch_size( + real_batch_size, is_prompt) + batch_size_padding = batch_size_padded - real_batch_size + seq_group_metadata_list = seq_group_metadata_list.copy() + if batch_size_padding > 0: + dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( + 0, 0, is_prompt) + seq_group_metadata_list.extend(dummy_seq_group_metadata + for _ in range(batch_size_padding)) + return seq_group_metadata_list + @torch.inference_mode() def execute_model( self, @@ -524,95 +560,264 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, warmup_mode=False, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "num_steps > 1 is not supported in HPUEncoderDecoderModelRunner" - ) - - input_tokens = 
model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - sampling_metadata = model_input.sampling_metadata - real_batch_size = model_input.real_batch_size - batch_size_padded = model_input.batch_size_padded - assert input_tokens is not None - assert input_positions is not None - assert sampling_metadata is not None - assert attn_metadata is not None - is_prompt = attn_metadata.is_prompt - assert is_prompt is not None - batch_size = input_tokens.size(0) - seq_len = self._seq_len(attn_metadata) - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) - self._check_config(batch_size, seq_len, is_prompt, warmup_mode) - - execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, - "kv_caches": kv_caches, - "attn_metadata": self.trim_attn_metadata(attn_metadata), - "intermediate_tensors": intermediate_tensors, - **(model_input.multi_modal_kwargs or {}), - } - if htorch.utils.internal.is_lazy(): - execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs}) - - htorch.core.mark_step() - if self.is_driver_worker: - model_event_name = ("model_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}") + previous_hidden_states: Optional[torch.Tensor] = None, + seqs=None, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + if not model_input.is_first_multi_step: + if not model_input.is_last_step: + # not first or last multi-step + return [] + # last multi-step + output = self._decode_sampler_outputs( + model_input) if self.is_driver_worker else [] + torch.hpu.synchronize() + if model_input.is_first_multi_step: + # first multi-step + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + sampling_metadata = model_input.sampling_metadata + real_batch_size = model_input.real_batch_size + batch_size_padded = model_input.batch_size_padded + assert input_tokens is not None + assert input_positions is not None + assert sampling_metadata is not None + assert attn_metadata is not None + is_prompt = attn_metadata.is_prompt + assert is_prompt is not None + batch_size = input_tokens.size(0) + seq_len = self._seq_len(attn_metadata) + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + self._check_config(batch_size, seq_len, is_prompt, warmup_mode) + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": self.trim_attn_metadata(attn_metadata), + "intermediate_tensors": intermediate_tensors, + **(model_input.multi_modal_kwargs or {}), + } + if previous_hidden_states is not None: + execute_model_kwargs.update( + {"previous_hidden_states": previous_hidden_states}) + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update( + {"bypass_hpu_graphs": not use_graphs}) + + htorch.core.mark_step() + if self.is_driver_worker: + model_event_name = ("model_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}") + else: + model_event_name = 'model_executable' + if num_steps > 1: + # in case of multi-step scheduling + # we only want to pythonize in the last step + sampling_metadata.skip_sampler_cpu_output = True + self.model.model.sampler.include_gpu_probs_tensor = True + cache_orig_output_tokens_len: List[Dict] = [] + + def try_revert_dummy_output_tokens(): + if len(cache_orig_output_tokens_len) > 0: 
+ # Reuse the original output token ids length + for i in range(len(cache_orig_output_tokens_len)): + seq_group_metadata = seq_group_metadata_list[i] + for j, data in seq_group_metadata.seq_data.items(): + orig_output_tokens_len = \ + cache_orig_output_tokens_len[i][j] + data.output_token_ids = \ + data.output_token_ids[:orig_output_tokens_len] + + for i in range(num_steps): + if i != 0 and not self.is_driver_worker: + broadcast_data = broadcast_tensor_dict(src=0) + if 'early_exit' in broadcast_data and broadcast_data[ + 'early_exit']: + return [output] if num_steps == 1 else [] + execute_model_kwargs.update({ + "input_ids": + broadcast_data["input_ids"], + "positions": + broadcast_data["positions"], + "attn_metadata": + self.trim_attn_metadata( + broadcast_data["attn_metadata"]) + }) + with self.profiler.record_event('internal', model_event_name): + hidden_states = self.model.forward( + **execute_model_kwargs, + selected_token_indices=sampling_metadata. + selected_token_indices) + + # Compute the logits. + with self.profiler.record_event( + 'internal', + ('compute_logits_' + f'{"prompt" if is_prompt else "decode"}_bs' + f'{batch_size}_' + f'seq{seq_len}')): + if num_steps == 1: + sampling_metadata.selected_token_indices = None + logits = self.model.compute_logits(hidden_states, + sampling_metadata) + htorch.core.mark_step() + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + continue + + if model_input.async_callback is not None: + model_input.async_callback() + # Sample the next token. + with self.profiler.record_event( + 'internal', ('sample_' + f'{"prompt" if is_prompt else "decode"}_' + f'bs{batch_size}_' + f'seq{seq_len}')): + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + if num_steps > 1: + output = output.sampled_token_ids + self.cached_step_outputs.append( + output.detach().clone()) + htorch.core.mark_step() + if i < num_steps - 1: + if i == 0: + if model_input.async_callback is not None: + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + seq_group_metadata_list = \ + ctx.seq_group_metadata_list + elif seqs is not None: + seq_group_metadata_list = seqs + else: + raise RuntimeError( + "seq_group_metadata_list is uninitialized") + for seq_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): + # Skip empty steps + seq_group_metadata.state.current_step += ( + num_steps - 2) + # Cache the original output token ids + cache_orig_output_tokens_len.append({}) + for j, data in seq_group_metadata.seq_data.items(): + cache_orig_output_tokens_len[seq_idx][j] = \ + len(data.output_token_ids) + seq_group_metadata_list = self.add_dummy_seq( + seq_group_metadata_list, is_prompt=False) + for seq_group_metadata in seq_group_metadata_list: + for data in seq_group_metadata.seq_data.values(): + max_output_len = sampling_metadata.seq_groups[ + 0].sampling_params.max_tokens + if len(data.output_token_ids) < max_output_len - 1: + # add a place holder for prepare_decode + # arbitrary value, this could be any token + dummy_token = (540, ) + data.output_token_ids += (dummy_token) + else: + broadcast_tensor_dict({'early_exit': True}, + src=0) + if num_steps == 1: + return [output] + else: + try_revert_dummy_output_tokens() + return [] + + result = self._prepare_decode(seq_group_metadata_list, + output=output) + execute_model_kwargs.update({ + "input_ids": + result.input_tokens, + "positions": + result.input_positions, + "attn_metadata": + self.trim_attn_metadata(result.attn_metadata) + }) + 
model_kwargs_broadcast_data = { + "input_ids": result.input_tokens, + "positions": result.input_positions, + "attn_metadata": vars(result.attn_metadata) + } + broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) + else: + try_revert_dummy_output_tokens() + + if self.is_driver_worker and self.profiler.enabled: + # Stop recording 'execute_model' event + self.profiler.end() + event_end = self.profiler.get_timestamp_us() + counters = self.profiler_counter_helper.get_counter_dict( + cache_config=self.cache_config, + duration=event_end - self.event_start, + seq_len=seq_len, + batch_size_padded=batch_size_padded, + real_batch_size=real_batch_size, + is_prompt=is_prompt) + self.profiler.record_counter(self.event_start, counters) + if num_steps == 1: + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states + return [output] if self.is_driver_worker else [] + else: + return [] + + return output if type(output) is list else [output] + + def _decode_sampler_outputs(self, model_input): + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = self._make_decode_output( + next_token_ids, model_input.sampling_metadata.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=False) + model_input.async_callback() + + if use_async_out_proc: + return [sampler_outputs[-1]] else: - model_event_name = 'model_executable' - with self.profiler.record_event('internal', model_event_name): - hidden_states = self.model.forward( - **execute_model_kwargs, - selected_token_indices=sampling_metadata.selected_token_indices - ) + return sampler_outputs - # Compute the logits. - with self.profiler.record_event( - 'internal', ('compute_logits_' - f'{"prompt" if is_prompt else "decode"}_bs' - f'{batch_size}_' - f'seq{seq_len}')): - sampling_metadata.selected_token_indices = None - logits = self.model.compute_logits(hidden_states, - sampling_metadata) - htorch.core.mark_step() - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return [] - - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. 
- with self.profiler.record_event( - 'internal', ('sample_' - f'{"prompt" if is_prompt else "decode"}_' - f'bs{batch_size}_' - f'seq{seq_len}')): - output = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, - ) - output.outputs = output.outputs[:real_batch_size] - htorch.core.mark_step() - - if self.is_driver_worker and self.profiler.enabled: - # Stop recording 'execute_model' event - self.profiler.end() - event_end = self.profiler.get_timestamp_us() - counters = self.profiler_counter_helper.get_counter_dict( - cache_config=self.cache_config, - duration=event_end - self.event_start, - seq_len=seq_len, - batch_size_padded=batch_size_padded, - real_batch_size=real_batch_size, - is_prompt=is_prompt) - self.profiler.record_counter(self.event_start, counters) - return [output] + def _make_decode_output( + self, + next_token_ids: List[List[int]], + seq_groups: List[SequenceGroupToSample], + ) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx][0] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return SamplerOutput(sampler_outputs) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 0658d17edb0bc..fe283b95a1c23 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1041,11 +1041,14 @@ def _prepare_decode( input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + cross_block_tables: List[List[int]] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[List[int]] = [] lora_prompt_mapping: List[List[int]] = [] lora_requests: Set[LoRARequest] = set() + is_enc_dec_model = self.model_config.is_encoder_decoder if len(seq_group_metadata_list) == 0: return PrepareDecodeMetadata.empty() lora_ids: List[int] = [] @@ -1060,6 +1063,15 @@ def _prepare_decode( seq_ids = list(seq_group_metadata.seq_data.keys()) lora_id = seq_group_metadata.lora_int_id lora_ids.append(lora_id) + if is_enc_dec_model: + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_len = ( + seq_group_metadata.encoder_seq_data.get_len() + if seq_group_metadata.encoder_seq_data else 0) + encoder_seq_lens.append(encoder_seq_len) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) @@ -1130,6 +1142,30 @@ def _prepare_decode( assert len(block_list) == len(block_groups) assert len(block_list) == len(block_usage) + if is_enc_dec_model: + last_cross_block_usage = [ + (encoder_seq_len - 1) % self.block_size + 1 + for encoder_seq_len in encoder_seq_lens + ] + cross_block_groups = [[i] * len(bt) + for i, bt in enumerate(cross_block_tables)] + cross_block_usage = [ + [self.block_size] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(cross_block_tables, last_cross_block_usage) + if bt + ] + cross_block_list = flatten(cross_block_tables) + cross_block_groups = flatten(cross_block_groups) + cross_block_usage = flatten(cross_block_usage) + assert len(cross_block_list) == len(cross_block_groups) + assert len(cross_block_list) == len(cross_block_usage) + + else: + cross_block_list = None + 
cross_block_groups = None + cross_block_usage = None + encoder_seq_lens_tensor = None + padding_fn = None if self.use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) @@ -1151,6 +1187,50 @@ def _prepare_decode( block_groups = padding_fn(block_groups, -1) block_usage = padding_fn(block_usage, 1) + if is_enc_dec_model: + if self.use_contiguous_pa: + cross_block_bucket_size = max( + max(cross_block_list) + + 1, len(cross_block_list)) if cross_block_list else 0 + cross_block_bucket_size = \ + self.bucketing_ctx.get_padded_decode_num_blocks( + cross_block_bucket_size) + indices = [None] * cross_block_bucket_size + for i, bid in enumerate(cross_block_list): + indices[bid] = i + padding_fn = lambda tensor, pad_value: gather_list( + tensor, indices, pad_value) + else: + cross_block_bucket_size = \ + self.bucketing_ctx.get_padded_decode_num_blocks( + len(cross_block_list)) + padding_fn = lambda tensor, pad_value: pad_list( + tensor, cross_block_bucket_size, pad_value) + + real_batch_size = len(seq_group_metadata_list) + batch_size_padded = self.bucketing_ctx.get_padded_batch_size( + real_batch_size, False) + batch_size_padding = batch_size_padded - real_batch_size + if batch_size_padding > 0: + encoder_seq_lens.extend(encoder_seq_lens[0] + for _ in range(batch_size_padding)) + cross_block_list = padding_fn(cross_block_list, _PAD_BLOCK_ID) + cross_block_groups = padding_fn(cross_block_groups, -1) + cross_block_usage = padding_fn(cross_block_usage, 1) + + cross_block_list = torch.tensor(cross_block_list, + dtype=torch.int, + device='cpu') + cross_block_groups = torch.tensor(cross_block_groups, + dtype=torch.int, + device='cpu') + cross_block_usage = torch.tensor(cross_block_usage, + dtype=self.model_config.dtype, + device='cpu') + encoder_seq_lens_tensor = torch.tensor(encoder_seq_lens, + dtype=torch.long, + device='cpu') + block_list = torch.tensor(block_list, dtype=torch.int, device='cpu') block_groups = torch.tensor(block_groups, dtype=torch.int, @@ -1174,6 +1254,15 @@ def _prepare_decode( self.device, non_blocking=True) slot_mapping = slot_mapping.to( # type: ignore self.device, non_blocking=True) + if is_enc_dec_model: + cross_block_list = cross_block_list.to( # type: ignore + self.device, non_blocking=True) + cross_block_groups = cross_block_groups.to( # type: ignore + self.device, non_blocking=True) + cross_block_usage = cross_block_usage.to( # type: ignore + self.device, non_blocking=True) + encoder_seq_lens_tensor = encoder_seq_lens_tensor.to( # type: ignore + self.device, non_blocking=True) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, @@ -1186,6 +1275,11 @@ def _prepare_decode( block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + cross_block_list=cross_block_list, + cross_block_groups=cross_block_groups, + cross_block_usage=cross_block_usage, context_lens_tensor=None, num_prefills=0, num_prefill_tokens=0,
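
The cross-attention metadata added to `_prepare_decode` above mirrors the existing self-attention block bookkeeping: every block in a sequence's cross block table is tagged with the index of the sequence that owns it (`cross_block_groups`), counted as fully occupied except for the last block, whose usage is `(encoder_seq_len - 1) % block_size + 1`, and the per-sequence lists are then flattened before padding. The following is a minimal standalone sketch of that computation, not part of the patch, assuming a block size of 128 and made-up encoder lengths and block tables; the `flatten` helper is re-implemented locally to keep the snippet self-contained.

import itertools
from typing import List

BLOCK_SIZE = 128  # illustrative; stands in for the runner's self.block_size

def flatten(nested: List[List[int]]) -> List[int]:
    # Local stand-in for the flatten() helper used in hpu_model_runner.py.
    return list(itertools.chain.from_iterable(nested))

# Two decoding sequences with hypothetical encoder lengths and cross block tables.
encoder_seq_lens = [200, 50]
cross_block_tables = [[3, 7], [9]]

# Usage of the final (partially filled) block of each cross block table.
last_cross_block_usage = [
    (seq_len - 1) % BLOCK_SIZE + 1 for seq_len in encoder_seq_lens
]  # -> [72, 50]

# Every block is labelled with the index of the sequence that owns it.
cross_block_groups = [[i] * len(bt) for i, bt in enumerate(cross_block_tables)]

# All blocks are full except the last one of each table.
cross_block_usage = [
    [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
    for bt, lbu in zip(cross_block_tables, last_cross_block_usage)
    if bt
]

cross_block_list = flatten(cross_block_tables)    # [3, 7, 9]
cross_block_groups = flatten(cross_block_groups)  # [0, 0, 1]
cross_block_usage = flatten(cross_block_usage)    # [128, 72, 50]
assert len(cross_block_list) == len(cross_block_groups) == len(cross_block_usage)

In the patch these flat lists are then padded (with `_PAD_BLOCK_ID`, -1 and 1 respectively) up to the bucketed block count and converted to CPU tensors before being moved to the device, following the same path the self-attention `block_list`/`block_groups`/`block_usage` already take.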