
Commit

No logits saving when sampling
mfylcek committed Jan 27, 2025
1 parent 74c87c5 commit 346ff80
Showing 1 changed file with 6 additions and 2 deletions.
vllm/worker/hpu_model_runner.py
@@ -803,6 +803,9 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
             real_batch_size, is_prompt)
         batch_size_padding = batch_size_padded - real_batch_size
 
+        #! TODO: batch size padding breaks accuracy
+        batch_size_padding = 0
+
         seq_group_metadata_list = seq_group_metadata_list.copy()
 
         if batch_size_padding > 0:
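
For context on what this hunk disables: on HPU, batches are padded up to bucketed sizes with dummy sequences so that compiled graph shapes can be reused; pinning batch_size_padding to 0 is a workaround that makes the dummy-append branch below it dead code. A minimal sketch of the logic being disabled, where pad_batch and make_dummy are illustrative stand-ins rather than vLLM's actual helpers:

def pad_batch(seq_group_metadata_list, batch_size_padded, make_dummy):
    # Pad a batch up to a bucketed size by appending dummy sequences.
    real_batch_size = len(seq_group_metadata_list)
    batch_size_padding = batch_size_padded - real_batch_size

    # The hunk above pins this to 0, so no dummy sequences are appended
    # and the batch keeps its real size (trading bucket reuse for accuracy).
    batch_size_padding = 0

    padded_list = seq_group_metadata_list.copy()
    if batch_size_padding > 0:
        # Dead branch while the workaround is in place.
        padded_list.extend(make_dummy() for _ in range(batch_size_padding))
    return padded_list, real_batch_size
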
@@ -2161,6 +2164,7 @@ def execute_model(
 
         htorch.core.mark_step()
 
+        #breakpoint()
         input_ids = None
         # Delayed sampling
         # Sample the next token based on previous logits if any.
@@ -2340,8 +2344,8 @@ def try_revert_dummy_output_tokens():
         for idx, seq_group_metadata in enumerate(model_input.seq_group_metadata_list):
             assert len(seq_group_metadata.seq_data) == 1
             for seq_data in seq_group_metadata.seq_data.values():
-                seq_data.prev_logits = logits
-                seq_data.prev_logits_idx = idx
+                seq_data.prev_logits = logits if not should_sample else None
+                seq_data.prev_logits_idx = idx if not should_sample else None
 
         htorch.core.mark_step()
         # Only perform sampling in the driver worker.
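
The last hunk is what the commit title describes: with delayed sampling, each step caches its logits so the next step can sample from them, and after this change the cache is cleared whenever sampling was already performed in the current step, so consumed logits are never saved again. A minimal sketch of that invariant, where SeqData is an illustrative stand-in rather than vLLM's actual sequence-data class:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SeqData:
    # Stand-in for the per-sequence state the diff touches.
    prev_logits: Optional[Any] = None      # logits cached for the next step
    prev_logits_idx: Optional[int] = None  # this sequence's row in that tensor


def save_for_delayed_sampling(seq_data, logits, idx, should_sample):
    # Cache the logits only when this step did not sample from them;
    # once sampled, they are stale and must not be reused next step.
    seq_data.prev_logits = logits if not should_sample else None
    seq_data.prev_logits_idx = idx if not should_sample else None
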
