Fix for random sampler recompilations for incomplete batches #663

Merged on Jan 22, 2025 (14 commits)
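
As the title and the diff below suggest, padding an incomplete batch up to the next bucket size used to append dummy sequences whose sampling temperature did not necessarily match the real requests, so a padded batch could flip between greedy and random sampling and trigger recompilations in the HPU sampler. The padding logic is now consolidated into a single _add_dummy_seq helper that picks the dummy temperature from the real requests: 0.0 if any real request samples greedily, 1.0 otherwise. The helper also returns real_batch_size and batch_size_padded, which prepare_input_tensors keeps using while the call site inside execute_model discards them.

A minimal standalone sketch of the padding idea; FakeSeqGroup, pad_batch, and the bucket sizes are illustrative stand-ins, not the actual vLLM API:

from dataclasses import dataclass
from typing import List


@dataclass
class FakeSeqGroup:
    # Stand-in for a sequence-group metadata object; only the sampling
    # temperature matters for this illustration.
    temperature: float


def pad_batch(batch: List[FakeSeqGroup],
              bucket_sizes: List[int]) -> List[FakeSeqGroup]:
    # Pad an incomplete batch up to the next configured bucket size, mirroring
    # the _add_dummy_seq helper below: the dummy entries reuse the sampling
    # mode of the real requests (greedy if any real request is greedy).
    padded_size = next(b for b in sorted(bucket_sizes) if b >= len(batch))
    has_greedy = any(seq.temperature == 0.0 for seq in batch)
    dummy_temperature = 0.0 if has_greedy else 1.0
    padding = [FakeSeqGroup(dummy_temperature)
               for _ in range(padded_size - len(batch))]
    return batch + padding


# Three greedy requests padded to a bucket of four stay all-greedy, so the
# sampler keeps seeing a homogeneous batch instead of a greedy/random mix.
print([s.temperature for s in pad_batch([FakeSeqGroup(0.0)] * 3, [1, 2, 4, 8])])
# -> [0.0, 0.0, 0.0, 0.0]
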
46 changes: 22 additions & 24 deletions vllm/worker/hpu_model_runner.py
@@ -792,6 +792,25 @@ def load_model(self) -> None:
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)

+    def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
+        real_batch_size = len(seq_group_metadata_list)
+        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
+            real_batch_size, is_prompt)
+        batch_size_padding = batch_size_padded - real_batch_size
+
+        seq_group_metadata_list = seq_group_metadata_list.copy()
+
+        if batch_size_padding > 0:
+            has_greedy_samples = any(
+                seq_group_metadata.sampling_params.temperature == 0.0
+                for seq_group_metadata in seq_group_metadata_list)
+            temperature = 0.0 if has_greedy_samples else 1.0
+            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
+                0, 0, is_prompt, temperature=temperature)
+            seq_group_metadata_list.extend(dummy_seq_group_metadata
+                                           for _ in range(batch_size_padding))
+        return seq_group_metadata_list, real_batch_size, batch_size_padded
+
     def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
         return htorch.hpu.wrap_in_hpu_graph(
             HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
@@ -1256,16 +1275,8 @@ def prepare_input_tensors(
         base_event_name = 'prompt' if is_prompt else 'decode'
         self.profiler.start('internal', base_event_name)

-        real_batch_size = len(seq_group_metadata_list)
-        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-            real_batch_size, is_prompt)
-        batch_size_padding = batch_size_padded - real_batch_size
-        seq_group_metadata_list = seq_group_metadata_list.copy()
-        if batch_size_padding > 0:
-            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                0, 0, is_prompt)
-            seq_group_metadata_list.extend(dummy_seq_group_metadata
-                                           for _ in range(batch_size_padding))
+        seq_group_metadata_list, real_batch_size, batch_size_padded = (
+            self._add_dummy_seq(seq_group_metadata_list, is_prompt))

         prefill_reqs = []
         decode_reqs = []
@@ -2066,19 +2077,6 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

         return lora_mask, lora_logits_mask

-    def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
-        real_batch_size = len(seq_group_metadata_list)
-        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-            real_batch_size, is_prompt)
-        batch_size_padding = batch_size_padded - real_batch_size
-        seq_group_metadata_list = seq_group_metadata_list.copy()
-        if batch_size_padding > 0:
-            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                0, 0, is_prompt)
-            seq_group_metadata_list.extend(dummy_seq_group_metadata
-                                           for _ in range(batch_size_padding))
-        return seq_group_metadata_list
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -2255,7 +2253,7 @@ def try_revert_dummy_output_tokens():
                             for j, data in seq_group_metadata.seq_data.items():
                                 cache_orig_output_tokens_len[seq_idx][j] = \
                                     len(data.output_token_ids)
-                        seq_group_metadata_list = self.add_dummy_seq(
+                        seq_group_metadata_list, _, _ = self._add_dummy_seq(
                             seq_group_metadata_list, is_prompt=False)
                         for seq_group_metadata in seq_group_metadata_list:
                             for data in seq_group_metadata.seq_data.values():
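
Note: returning real_batch_size alongside the padded list presumably lets the caller drop whatever the model produced for the dummy sequences once the forward pass is done. A toy illustration of that trimming step (trim_to_real_batch is a hypothetical helper, not part of the runner):

from typing import List, TypeVar

T = TypeVar("T")


def trim_to_real_batch(outputs: List[T], real_batch_size: int) -> List[T]:
    # Everything past real_batch_size belongs to the padding dummies and can
    # be discarded.
    return outputs[:real_batch_size]


# Padded batch of 4 with only 3 real requests: the last output is dropped.
print(trim_to_real_batch(["tok_a", "tok_b", "tok_c", "tok_dummy"],
                         real_batch_size=3))
# -> ['tok_a', 'tok_b', 'tok_c']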