From 346ff80dd652d53d02861be8054bff547fd27d7c Mon Sep 17 00:00:00 2001 From: Marceli Fylcek Date: Mon, 27 Jan 2025 11:50:55 +0200 Subject: [PATCH] No logits saving when sampling --- vllm/worker/hpu_model_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index bb128fb327f36..2c23a540a1ac8 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -803,6 +803,9 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): real_batch_size, is_prompt) batch_size_padding = batch_size_padded - real_batch_size + #! TODO: batch size padding breaks accuracy + batch_size_padding =0 + seq_group_metadata_list = seq_group_metadata_list.copy() if batch_size_padding > 0: @@ -2161,6 +2164,7 @@ def execute_model( htorch.core.mark_step() + #breakpoint() input_ids = None # Delayed sampling # Sample the next token based on previous logits if any. @@ -2340,8 +2344,8 @@ def try_revert_dummy_output_tokens(): for idx, seq_group_metadata in enumerate(model_input.seq_group_metadata_list): assert len(seq_group_metadata.seq_data) == 1 for seq_data in seq_group_metadata.seq_data.values(): - seq_data.prev_logits = logits - seq_data.prev_logits_idx = idx + seq_data.prev_logits = logits if not should_sample else None + seq_data.prev_logits_idx = idx if not should_sample else None htorch.core.mark_step() # Only perform sampling in the driver worker.