BS padding fix
mfylcek committed Jan 27, 2025
1 parent 346ff80 commit c3cfca9
Showing 1 changed file with 22 additions and 24 deletions.
vllm/worker/hpu_model_runner.py (46 changes: 22 additions & 24 deletions)
@@ -803,9 +803,6 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
             real_batch_size, is_prompt)
         batch_size_padding = batch_size_padded - real_batch_size
 
-        #! TODO: batch size padding breakes accuracy
-        batch_size_padding =0
-
         seq_group_metadata_list = seq_group_metadata_list.copy()
 
         if batch_size_padding > 0:
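This first hunk is the fix named in the commit title: it deletes the temporary workaround that pinned `batch_size_padding` to zero (along with its note that padding "breakes" accuracy), so batches are again padded up to the next bucketed size. For orientation, here is a minimal sketch of what bucketed batch-size padding does; the bucket list and the `make_dummy` callback are illustrative stand-ins, not the fork's actual API:

```python
# Minimal sketch of bucketed batch-size padding. PADDED_BUCKETS and
# make_dummy are illustrative assumptions, not vLLM's real API.
PADDED_BUCKETS = [1, 2, 4, 8, 16, 32, 64]

def pad_batch(seq_group_metadata_list, make_dummy):
    real_batch_size = len(seq_group_metadata_list)
    # Round up to the nearest pre-warmed bucket so compiled HPU graph
    # shapes can be reused instead of triggering recompilation.
    batch_size_padded = next(
        (b for b in PADDED_BUCKETS if b >= real_batch_size),
        real_batch_size)
    batch_size_padding = batch_size_padded - real_batch_size
    padded = list(seq_group_metadata_list)
    # The dummy tail exists only to fill out the shape; its outputs must
    # be discarded and must never leak into sampling.
    padded.extend(make_dummy() for _ in range(batch_size_padding))
    return padded, real_batch_size
```

Padding to a fixed set of batch sizes keeps HPU graph shapes stable across steps; the rest of the commit makes the delayed-sampling path tolerate the dummy entries, which appears to be where the accuracy problem came from.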
@@ -1485,6 +1482,7 @@ def create_dummy_seq_group_metadata(self,
         output_token_ids = [1] * output_len
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         seq_data = SequenceData(prompt_token_ids_array)
+
         seq_data.output_token_ids = output_token_ids
         return SequenceGroupMetadata(request_id=str(group_id),
                                      is_prompt=(output_len == 0),
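For context, the surrounding `create_dummy_seq_group_metadata` builds the placeholder sequences used for that padding: a zero-filled prompt, `output_len` dummy output tokens, and `is_prompt` derived from `output_len == 0`. A condensed, hypothetical sketch of the same construction (the real constructor takes further arguments, e.g. block tables, which are omitted here):

```python
from array import array

from vllm.sequence import SequenceData, SequenceGroupMetadata

def dummy_seq_group(group_id: int, input_len: int,
                    output_len: int, sampling_params=None):
    # Placeholder tokens only: 0s for the prompt, 1s for tokens that are
    # pretended to have been decoded already. output_len == 0 means the
    # dummy sequence is still in the prefill phase.
    seq_data = SequenceData(array('l', [0] * input_len))
    seq_data.output_token_ids = [1] * output_len
    return SequenceGroupMetadata(request_id=str(group_id),
                                 is_prompt=(output_len == 0),
                                 seq_data={group_id: seq_data},
                                 sampling_params=sampling_params)
```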
@@ -2164,7 +2162,6 @@ def execute_model(

         htorch.core.mark_step()
 
-        #breakpoint()
         input_ids = None
         # Delayed sampling
         # Sample the next token based on previous logits if any.
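The `# Delayed sampling` comments mark the scheme used here: the token for step N is not sampled right after step N's forward pass, but at the start of step N+1, from logits cached on each sequence as `prev_logits` plus a row index `prev_logits_idx`. The usual motivation is to let the device begin the next forward pass without waiting on host-side sampling. A toy sketch of the idea; only the two cached field names come from the diff, everything else is illustrative:

```python
import torch

class Seq:
    """Toy stand-in for a sequence carrying cached previous-step logits."""

    def __init__(self):
        self.tokens = []
        self.prev_logits = None   # [batch_size, vocab] logits of last step
        self.prev_logits_idx = 0  # this sequence's row in that tensor

def step(model_fn, seqs, batch):
    # 1) Sample the token that was deferred from the previous step.
    for seq in seqs:
        if seq.prev_logits is not None:
            row = seq.prev_logits[seq.prev_logits_idx]
            seq.tokens.append(int(torch.argmax(row, dim=-1)))  # greedy
    # 2) Run this step's forward pass and cache its logits; they are
    #    sampled at the start of the next call instead of right now.
    logits = model_fn(batch)
    for i, seq in enumerate(seqs):
        seq.prev_logits = logits
        seq.prev_logits_idx = i
```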
@@ -2175,33 +2172,34 @@
         logits_tensor = None
         logits_tensor_list = []
         if model_input.seq_group_metadata_list is not None:
-            for seq_group_metadata in model_input.seq_group_metadata_list:
+            for i, seq_group_metadata in enumerate(model_input.seq_group_metadata_list):
                 assert len(seq_group_metadata.seq_data) == 1
                 for seq_data in seq_group_metadata.seq_data.values():
-                    if seq_data.prev_logits is not None:
-                        if logits_tensor is None:
-                            logits_tensor = seq_data.prev_logits
-                        if seq_data.prev_logits is logits_tensor:
-                            logits_ids_list.append(
-                                seq_data.prev_logits_idx)
-                        else:
-                            logits_tensor_list.append(
-                                logits_tensor[torch.tensor(
-                                    logits_ids_list,
-                                    device=seq_data.prev_logits.device)])
-
-                            logits_ids_list = [seq_data.prev_logits_idx]
-                            logits_tensor = seq_data.prev_logits
+                    if seq_data.prev_logits is None:
+                        # Padded sequences and warmup
+                        #TODO: Add some sort of check based on metadata(?)
+                        seq_data.prev_logits = torch.zeros([1, 32000],
+                                                           dtype=torch.float,
+                                                           device="hpu")
+                        seq_data.prev_logits_idx = i
+                    if logits_tensor is None:
+                        logits_tensor = seq_data.prev_logits
+                    if seq_data.prev_logits is logits_tensor:
+                        logits_ids_list.append(
+                            seq_data.prev_logits_idx)
                     else:
-                        # warmup only, TODO add a check
                         logits_tensor_list.append(
-                            torch.zeros([1, 32000],
-                                        dtype=torch.float,
-                                        device="hpu"))
+                            logits_tensor[torch.tensor(
+                                logits_ids_list,
+                                device=seq_data.prev_logits.device)])
+
+                        logits_ids_list = [seq_data.prev_logits_idx]
+                        logits_tensor = seq_data.prev_logits
 
         if logits_tensor is not None:
             logits_tensor_list.append(logits_tensor[torch.tensor(
                 logits_ids_list, device=logits_tensor.device)])
 
         prev_logits = torch.cat(logits_tensor_list, dim=0)
 
         # Sample next token - delayed sampling
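The rewritten loop above walks the batch and groups consecutive sequences whose `prev_logits` reference the same tensor, so each run of rows is gathered with a single index operation and all runs are concatenated once into `prev_logits`. The new `prev_logits is None` branch hands padded and warmup sequences a zeroed placeholder (with a seemingly hard-coded vocab size of 32000) and a valid row index, so they take the same path as real sequences; the old warmup-only `else` branch presumably mishandled padded sequences once batch-size padding was re-enabled. A standalone sketch of the run-grouping, with illustrative names:

```python
import torch

def gather_prev_logits(seqs):
    """Gather one cached logits row per sequence.

    Consecutive sequences sharing the same logits tensor are batched
    into a single index op, mirroring the diff. Assumes seqs is
    non-empty and every seq has prev_logits / prev_logits_idx set.
    """
    runs, run_ids, run_tensor = [], [], None
    for seq in seqs:
        if run_tensor is None:
            run_tensor = seq.prev_logits
        if seq.prev_logits is run_tensor:
            run_ids.append(seq.prev_logits_idx)
        else:
            # The tensor changed: flush the accumulated run, start anew.
            runs.append(run_tensor[torch.tensor(
                run_ids, device=run_tensor.device)])
            run_ids = [seq.prev_logits_idx]
            run_tensor = seq.prev_logits
    if run_tensor is not None:
        runs.append(run_tensor[torch.tensor(
            run_ids, device=run_tensor.device)])
    return torch.cat(runs, dim=0)  # [len(seqs), vocab]
```

Grouping by run keeps the number of gather ops proportional to the number of distinct cached tensors rather than the number of sequences.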