From 11a87dc49b0e39093a975b5b7a5172fd2ae804c9 Mon Sep 17 00:00:00 2001 From: Kamil Kaczor Date: Fri, 20 Dec 2024 15:00:23 +0200 Subject: [PATCH 1/2] DRAFT: start of implementation of delayed prompts First draft of what to change and where; not functional yet. Comments in execute_model mark ideas about what should be changed and where. --- vllm/config.py | 9 + vllm/engine/arg_utils.py | 9 + vllm/engine/output_processor/interfaces.py | 4 +- vllm/engine/output_processor/single_step.py | 4 +- vllm/model_executor/layers/sampler.py | 32 +++- vllm/sequence.py | 15 +- vllm/worker/hpu_model_runner.py | 188 +++++++++++++++++--- 7 files changed, 221 insertions(+), 40 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4e5c755055f1f..02926254cd843 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1129,6 +1129,9 @@ class SchedulerConfig: # tokens in prefill. use_padding_aware_scheduling: bool = False + # If True, delayed sampling will be enabled for prompts. + enable_delayed_sampling: bool = False + def __post_init__(self) -> None: if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: @@ -1204,6 +1207,12 @@ def _verify_args(self) -> None: raise ValueError("Padding-aware scheduling currently " "does not work with chunked prefill ") + if self.enable_delayed_sampling and self.num_lookahead_slots != 1: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be 1 for delayed sampling." + ) + @property def is_multi_step(self) -> bool: return self.num_scheduler_steps > 1 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9f932c6f26eaa..c167b1e3bfbaa 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -170,6 +170,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: Optional[bool] = None + enable_delayed_sampling: bool = False guided_decoding_backend: str = 'xgrammar' # Speculative decoding configuration. 
@@ -714,6 +715,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: const="True", help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens.') + parser.add_argument( + '--enable-delayed-sampling', + action='store_true', + help='If set, the sampling will be delayed by 1 step. First ' + 'model request execution (prefill) will return an invalid token ' + 'id that will be discarded. Actual sampling of valid token ids ' + 'starts from second model execution.') parser.add_argument( '--speculative-model', @@ -1178,6 +1186,7 @@ def create_engine_config(self, enable_chunked_prefill=self.enable_chunked_prefill, is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, + enable_delayed_sampling=self.enable_delayed_sampling, num_scheduler_steps=self.num_scheduler_steps, multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 50adaf4e59188..026020b629a5c 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -36,7 +36,9 @@ def create_output_processor( This returns a single-step output processor if num_lookahead_slots is zero, else returns a multi-step output processor. """ - if scheduler_config.num_lookahead_slots == 0: + if (scheduler_config.num_lookahead_slots == 0 + or (scheduler_config.num_lookahead_slots == 1 + and scheduler_config.enable_delayed_sampling)): # Importing here to avoid cycle. 
from vllm.engine.output_processor.single_step import ( SingleStepOutputProcessor) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index da3185f33dbe9..1c50bffd88417 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -116,7 +116,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, sample = outputs.samples[0] seq = seq_group.first_seq - if not is_async: + # -1 means the output token is not valid (eg. first token if + # delayed sampling is enabled). + if not is_async and sample != -1: seq.append_token_id(sample.output_token, sample.logprobs) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 8aa6646c5dcea..ab35f991c7576 100755 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -186,6 +186,7 @@ def __init__(self): # containing the sampled token ids and probabilities. This is used by # speculative decoding. self.include_gpu_probs_tensor = False + self.sample_token_positions_only = False self.should_modify_greedy_probs_inplace = False def _init_sampling_tensors( @@ -302,6 +303,7 @@ def forward( sampling_tensors, include_gpu_probs_tensor=self.include_gpu_probs_tensor, modify_greedy_probs=self._should_modify_greedy_probs_inplace, + token_positions_only=self.sample_token_positions_only, ) if self.include_gpu_probs_tensor: @@ -588,6 +590,7 @@ def _apply_min_p( def _greedy_sample( selected_seq_groups: List[SequenceGroupToSample], samples: torch.Tensor, + token_positions_only: bool = False, ) -> SampleResultType: """Run greedy sampling on a given samples. @@ -601,7 +604,9 @@ def _greedy_sample( same as the length of selected_seq_groups. 
If the corresponding seq_group has do_sample=False, tuple contains ([], []) """ - samples_lst = samples.tolist() + print("DEBUG sync in greedy_sample") + if not token_positions_only: + samples_lst = samples.tolist() sample_idx = 0 results: SampleResultType = [] for seq_group in selected_seq_groups: @@ -614,7 +619,9 @@ def _greedy_sample( assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples_lst[sample_idx]] + next_token_ids = [ + sample_idx if token_positions_only else samples_lst[sample_idx] + ] results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs return results @@ -637,7 +644,8 @@ def _random_sample( seq_group has do_sample=False, tuple contains ([], []) """ # Find the maximum n value of the prompt phase requests. - random_samples = random_samples.cpu() + print("DEBUG sync in random_sample") + random_samples = random_samples.cpu() # TODO: use torch.zeros instead, to avoid copying to the CPU sample_idx = 0 results: SampleResultType = [] for seq_group in selected_seq_groups: @@ -653,7 +661,7 @@ def _random_sample( # Prompt phase. parent_ids = [0] * sampling_params.n next_token_ids = random_samples[ - sample_idx, :sampling_params.n].tolist() + sample_idx, :sampling_params.n].tolist() # this is also a host operation; is it needed? else: # Generation phase. parent_ids = list(range(num_parent_seqs)) @@ -798,7 +806,8 @@ def _top_k_top_p_multinomial_with_flashinfer( def get_pythonized_sample_results( - sample_result_args: SampleResultArgsType) -> SampleResultType: + sample_result_args: SampleResultArgsType, + token_positions_only: bool = False) -> SampleResultType: '''This function consumes GPU-side sampler results and computes Pythonized CPU-side sampler results (GPU -> CPU sync.) 
@@ -836,7 +845,8 @@ def get_pythonized_sample_results( continue (seq_group_id, seq_groups) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) + sample_results = _greedy_sample(seq_groups, greedy_samples, + token_positions_only) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample(seq_groups, multinomial_samples[sampling_type]) @@ -858,6 +868,7 @@ def _sample_with_torch( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, + token_positions_only: bool = False, ) -> SampleReturnType: '''Torch-oriented _sample() implementation. @@ -967,13 +978,14 @@ def _sample_with_torch( greedy_samples=greedy_samples, beam_search_logprobs=beam_search_logprobs, sample_results_dict=sample_results_dict) - + #import pdb; pdb.set_trace() if not sampling_metadata.skip_sampler_cpu_output: - # GPU<->CPU sync happens here. + # GPU<->CPU sync happens here, + # unless we're storing only token positions (token_positions_only=True). # This also converts the sampler output to a Python object. 
# Return Pythonized sampler result & sampled token ids return get_pythonized_sample_results( - maybe_deferred_args), sampled_token_ids_tensor + maybe_deferred_args, token_positions_only), sampled_token_ids_tensor else: # Defer sampler result Pythonization; return deferred # Pythonization args & sampled token ids @@ -990,6 +1002,7 @@ def _sample( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, + token_positions_only: bool, ) -> SampleReturnType: """ Args: @@ -1010,6 +1023,7 @@ def _sample( sampling_tensors, include_gpu_probs_tensor=include_gpu_probs_tensor, modify_greedy_probs=modify_greedy_probs, + token_positions_only=token_positions_only, ) diff --git a/vllm/sequence.py b/vllm/sequence.py index 669124319c4f4..f3c97d8370fae 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -175,6 +175,10 @@ class SequenceData(msgspec.Struct, # It is used to compute mrope_position_ids. _mrope_position_delta: Optional[int] = None + # it is used in delayed prompt sampling + prev_logits = None + prev_logits_idx = None + @staticmethod def from_prompt_token_counts( *token_counts: Tuple[int, int]) -> "SequenceData": @@ -279,11 +283,11 @@ def mrope_position_delta(self) -> Optional[int]: def mrope_position_delta(self, new_mrope_position_delta): self._mrope_position_delta = new_mrope_position_delta - def append_token_id(self, token_id: int, logprob: float) -> None: + def append_token_id(self, token_id: int, logprob: Optional[float]) -> None: self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) - self._cumulative_logprob += logprob + self._cumulative_logprob += logprob if logprob is not None else 0.0 def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) @@ -315,8 +319,6 @@ def get_num_computed_tokens(self) -> int: def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" 
self._num_computed_tokens += num_new_computed_tokens - assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) # If all tokens are computed, it means it is in decoding phase. if self.get_num_uncomputed_tokens() == 0: self._stage = SequenceStage.DECODE @@ -536,9 +538,10 @@ def reset_state_for_recompute(self): def append_token_id(self, token_id: int, logprobs: Dict[int, Logprob]) -> None: - assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) + self.data.append_token_id( + token_id, + logprobs[token_id].logprob if token_id in logprobs else None) def get_len(self) -> int: return self.data.get_len() diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7c3679d40546d..a33414d59b591 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -11,6 +11,7 @@ import math import os import time +import pdb from array import array from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, @@ -101,7 +102,7 @@ def align_workers(value, op): def setup_profiler(): - schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) + schedule = torch.profiler.schedule(wait=0, warmup=2, active=3, repeat=1) activities = [ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.HPU @@ -484,6 +485,7 @@ class ModelInputForHPU(ModelRunnerInputBase): async_callback: Optional[Callable] = None is_first_multi_step: bool = True is_last_step: bool = True + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -747,6 +749,8 @@ def load_model(self) -> None: names_for_rope = get_names_for_rope(self.model) torch.hpu.synchronize() + # FIXME: Check if running with disable_tensor_cache=True causes + # RuntimeErrors. It was a case in 1.18 delayed sampling. 
with HabanaMemoryProfiler() as m_wrap: self.model = self._maybe_wrap_in_hpu_graph( self.model, @@ -1049,7 +1053,9 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() input_tokens.append([generation_token]) - seq_len = seq_data.get_len() + seq_len = ((seq_data.get_num_computed_tokens() + + 1) if self.scheduler_config.enable_delayed_sampling + else seq_data.get_len()) position = seq_len - 1 input_positions.append([position]) @@ -1320,7 +1326,8 @@ def prepare_input_tensors( "num_prefills": num_prefills, "batch_type": batch_type, "seq_lens": seq_lens, - "query_lens": query_lens + "query_lens": query_lens, + "seq_group_metadata_list": seq_group_metadata_list, } if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) @@ -1341,7 +1348,8 @@ def prepare_input_tensors( multi_modal_kwargs=multi_modal_kwargs, real_batch_size=real_batch_size, batch_size_padded=batch_size_padded, - lora_ids=lora_ids), \ + lora_ids=lora_ids, + seq_group_metadata_list=seq_group_metadata_list), \ sampling_metadata def _seq_len(self, attn_metadata): @@ -1432,9 +1440,9 @@ def warmup_scenario(self, seq_len, is_prompt, kv_caches, - is_pt_profiler_run=False, + is_pt_profiler_run=True, is_lora_profile_run=False, - temperature=0) -> None: + temperature=1) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" @@ -1492,9 +1500,10 @@ def warmup_scenario(self, torch.hpu.synchronize() profiler = None if is_pt_profiler_run and self.is_driver_worker: + print(f'DEBUG setup profiler & start') profiler = setup_profiler() profiler.start() - for _ in range(times): + for _ in range(5): inputs = self.prepare_model_input(seqs) is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 @@ -1522,6 +1531,8 @@ def warmup_scenario(self, profiler.step() if profiler: profiler.stop() + if profiler: + profiler.cleanup() self.profiler.end() gc.collect() @@ -1642,6 
+1653,7 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: if profile := os.environ.get('VLLM_PT_PROFILE', None): + print('DEBUG Starting warmup profiling...') phase, bs, seq_len, graph = profile.split('_') is_prompt = phase == 'prompt' graphs = graph == 't' @@ -2043,6 +2055,7 @@ def execute_model( previous_hidden_states: Optional[torch.Tensor] = None, seqs=None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + if not model_input.is_first_multi_step: if not model_input.is_last_step: # not first or last multi-step @@ -2100,6 +2113,67 @@ def execute_model( {"bypass_hpu_graphs": not use_graphs}) htorch.core.mark_step() + + # Kamil: I think this might be the place to sample the prompt tokens. + # We are in the first step of MSS and before the loop for decodes. + # This code was copied from delayed sampling, it should be a good starting point + input_ids = None + # Sample the next token based on previous logits if any. 
+ if self.scheduler_config.enable_delayed_sampling \ + and self.is_driver_worker and not is_prompt: + logits_ids_list = [] + logits_tensor = None + logits_tensor_list = [] + if model_input.seq_group_metadata_list is not None: + for seq_group_metadata in model_input.seq_group_metadata_list: + assert len(seq_group_metadata.seq_data) == 1 + for seq_data in seq_group_metadata.seq_data.values(): + if seq_data.prev_logits is not None: + if logits_tensor is None: + logits_tensor = seq_data.prev_logits + if seq_data.prev_logits is logits_tensor: + # accumulate row ids from the same tensor + logits_ids_list.append( + seq_data.prev_logits_idx) + else: + # new logits tensor, + # gather all previously collected rows + logits_tensor_list.append( + logits_tensor[torch.tensor( + logits_ids_list, + device=seq_data.prev_logits.device)]) + logits_ids_list = [seq_data.prev_logits_idx] + logits_tensor = seq_data.prev_logits + else: + # warmup only, TODO add a check + logits_tensor_list.append( + torch.zeros([1, 32000], + dtype=torch.float, + device="hpu")) + if logits_tensor is not None: + logits_tensor_list.append(logits_tensor[torch.tensor( + logits_ids_list, device=seq_data.prev_logits.device)]) + prev_logits = torch.cat(logits_tensor_list, dim=0) + with self.profiler.record_event( + 'internal', f'sample_{"prompt" if is_prompt else "decode"}' + '_bs{batch_size}_seq{seq_len}'): + output = self.model.sample( + logits=prev_logits, + sampling_metadata=sampling_metadata, + ) + #TODO: check why broadcast failed for float tensor use dict instead + model_kwargs = {} + model_kwargs["input_ids"] = output.sampled_token_ids + broadcast_tensor_dict(model_kwargs, src=0) + input_ids = output.sampled_token_ids + elif self.scheduler_config.enable_delayed_sampling and not is_prompt: + model_kwargs = broadcast_tensor_dict(src=0) + input_ids = model_kwargs["input_ids"] + if input_ids is not None: + execute_model_kwargs["input_ids"] = input_ids + htorch.core.mark_step() + ############################## 
proposed block code end ############################## + if self.is_driver_worker: model_event_name = ("model_" f"{'prompt' if is_prompt else 'decode'}_" @@ -2108,8 +2182,17 @@ def execute_model( f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' - if num_steps > 1: - # in case of multi-step scheduling + + ## Kamil: Should we use these two flags in delayed sampling? + # they are used not to do a pythonized call in _sample_with_torch which forces a sync. + # Added this "or" here but original delayed sampling didn't add this flag + # but in sampler code added an if to not do a sync to CPU in the _greedy_sample + # function. Maybe we can reuse this flag instead of if in the _greedy_sample + # TODO: decide if we want to avoid pythonized call or if the sync in sampler + # Added here is_prompt + if num_steps > 1 or (is_prompt and self.scheduler_config.enable_delayed_sampling): + ############################## proposed block code end ############################## + # in case of multi-step scheduling or delayed sampling # we only want to pythonize in the last step sampling_metadata.skip_sampler_cpu_output = True self.model.model.sampler.include_gpu_probs_tensor = True @@ -2152,6 +2235,36 @@ def try_revert_dummy_output_tokens(): lora_logits_mask.index_select( 0, sampling_metadata.selected_token_indices)) + # Kamil: Delayed sampling is returning a -1 as an answer during the prompt + # then in decode it samples a real answer. Using -1 we can later in vllm catch it + # and skip that answer. + # TODO: do we want to keep the -1 return from delayed sampling? 
+ # Added here is_prompt + if is_prompt and self.scheduler_config.enable_delayed_sampling \ + and self.is_driver_worker: + # For prompts compose empty output + sampler_output = [] + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + next_token_id, parent_id = -1, 0 + seq_outputs = [] + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + {-1: Logprob(0.0)})) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + sampled_token_probs, logprobs_tensor, sampled_token_ids = ( + None, None, None) + output = SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, + ) + output.outputs = output.outputs[:real_batch_size] + htorch.core.mark_step() + ############################## proposed block code end ############################## + # Compute the logits. with self.profiler.record_event( 'internal', @@ -2163,6 +2276,20 @@ def try_revert_dummy_output_tokens(): sampling_metadata.selected_token_indices = None logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Kamil: delayed sampling uses prev_logits and their idx to be saved + # for later compute. Added here is_prompt + if is_prompt and (self.scheduler_config.enable_delayed_sampling + and model_input.seq_group_metadata_list is not None + and self.is_driver_worker): + for idx, seq_group_metadata in enumerate( + model_input.seq_group_metadata_list): + assert len(seq_group_metadata.seq_data) == 1 + for seq_data in seq_group_metadata.seq_data.values(): + seq_data.prev_logits = logits + seq_data.prev_logits_idx = idx + ############################## proposed block code end ############################## + htorch.core.mark_step() # Only perform sampling in the driver worker. 
if not self.is_driver_worker: @@ -2170,20 +2297,30 @@ def try_revert_dummy_output_tokens(): if model_input.async_callback is not None: model_input.async_callback() - # Sample the next token. - with self.profiler.record_event( - 'internal', ('sample_' - f'{"prompt" if is_prompt else "decode"}_' - f'bs{batch_size}_' - f'seq{seq_len}')): - output = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, - ) - if num_steps > 1: - output = output.sampled_token_ids - self.cached_step_outputs.append( - output.detach().clone()) + + # Kamil: delayed sampling skips sampling at the end in + # both decode and sample and always does it at the beginning of + # the next decode. + # TODO: do we want to keep delayed sampling skipping the sampling at the end? + # I think we should only do this sampling in the decodes and skip it in the + # prompts in prompt delayed sampling. + # So the prompts will skip sampling entirely and the decodes will do the sampling + # at the end for their own results and at the beginning for the prompts' outputs. + if not (self.scheduler_config.enable_delayed_sampling and is_prompt): + with self.profiler.record_event( + 'internal', ('sample_' + f'{"prompt" if is_prompt else "decode"}_' + f'bs{batch_size}_' + f'seq{seq_len}')): + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + if num_steps > 1: + output = output.sampled_token_ids + self.cached_step_outputs.append( + output.detach().clone()) + ############################## proposed block code end ############################## htorch.core.mark_step() if i < num_steps - 1: if i == 0: @@ -2258,6 +2395,9 @@ def try_revert_dummy_output_tokens(): real_batch_size=real_batch_size, is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) + ## Kamil: here is the part where MSS returns decodes at the end as a list + # in other steps it returns []. 
The output here if sampler for prompt was run + # during the decode needs to have appended the real output alongside the output from decodes. if num_steps == 1: if self.return_hidden_states: # we only need to pass hidden states of most recent token @@ -2268,6 +2408,8 @@ def try_revert_dummy_output_tokens(): return [output] if self.is_driver_worker else [] else: return [] + ############################## proposed block code end ############################## + return output if type(output) is list else [output] From 2fe25bd0bb7cacfde7a897f71cf45ce721cb9fde Mon Sep 17 00:00:00 2001 From: Kamil Kaczor Date: Fri, 20 Dec 2024 14:40:47 +0100 Subject: [PATCH 2/2] Update hpu_model_runner.py --- vllm/worker/hpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index a33414d59b591..6c76f62dcace8 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2395,9 +2395,9 @@ def try_revert_dummy_output_tokens(): real_batch_size=real_batch_size, is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) - ## Kamil: here is the part where MSS returns decodes at the end as a list - # in other steps it returns []. The output here if sampler for prompt was run - # during the decode needs to have appended the real output alongside the output from decodes. + ## Kamil: here is the part where MSS returns outputs of decodes at the end as a list in the first step + # and in other steps it returns []. If delayed sampler of prompts was run in the first step + # then we need to add to this list the outputs from sampling of prompts. if num_steps == 1: if self.return_hidden_states: # we only need to pass hidden states of most recent token