Fix for random sampler recompilations for incomplete batches (#663)
Changes the sampler used by dummy (padding) sequences to greedy whenever any
real sequence in the batch uses greedy sampling. This prevents sampler recompilations.
mfylcek authored Jan 22, 2025
1 parent 1df1c2c commit e977f2a
Showing 1 changed file with 22 additions and 24 deletions.
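The heart of the diff below is the temperature chosen for the dummy (padding) sequences. The following standalone sketch mirrors that selection logic outside the model runner; it is a minimal illustration, and FakeSamplingParams and pick_dummy_temperature are illustrative stand-ins rather than vLLM API.

from dataclasses import dataclass
from typing import List


@dataclass
class FakeSamplingParams:
    # Illustrative stand-in for vLLM's SamplingParams; only the field
    # relevant to this commit is modeled.
    temperature: float


def pick_dummy_temperature(real_params: List[FakeSamplingParams]) -> float:
    """Sketch of the selection rule added in _add_dummy_seq.

    If any real sequence in the batch samples greedily (temperature == 0.0),
    the padding sequences are made greedy as well, so padding alone never
    drags a random sampler into an otherwise greedy batch.
    """
    has_greedy_samples = any(p.temperature == 0.0 for p in real_params)
    return 0.0 if has_greedy_samples else 1.0


if __name__ == "__main__":
    greedy_batch = [FakeSamplingParams(0.0), FakeSamplingParams(0.0)]
    random_batch = [FakeSamplingParams(0.7), FakeSamplingParams(1.0)]
    print(pick_dummy_temperature(greedy_batch))  # 0.0 -> dummy seqs stay greedy
    print(pick_dummy_temperature(random_batch))  # 1.0 -> dummy seqs sample randomly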
vllm/worker/hpu_model_runner.py: 46 changes (22 additions & 24 deletions)
@@ -792,6 +792,25 @@ def load_model(self) -> None:
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)
 
+    def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
+        real_batch_size = len(seq_group_metadata_list)
+        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
+            real_batch_size, is_prompt)
+        batch_size_padding = batch_size_padded - real_batch_size
+
+        seq_group_metadata_list = seq_group_metadata_list.copy()
+
+        if batch_size_padding > 0:
+            has_greedy_samples = any(
+                seq_group_metadata.sampling_params.temperature == 0.0
+                for seq_group_metadata in seq_group_metadata_list)
+            temperature = 0.0 if has_greedy_samples else 1.0
+            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
+                0, 0, is_prompt, temperature=temperature)
+            seq_group_metadata_list.extend(dummy_seq_group_metadata
+                                           for _ in range(batch_size_padding))
+        return seq_group_metadata_list, real_batch_size, batch_size_padded
+
     def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
         return htorch.hpu.wrap_in_hpu_graph(
             HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
@@ -1256,16 +1275,8 @@ def prepare_input_tensors(
         base_event_name = 'prompt' if is_prompt else 'decode'
         self.profiler.start('internal', base_event_name)
 
-        real_batch_size = len(seq_group_metadata_list)
-        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-            real_batch_size, is_prompt)
-        batch_size_padding = batch_size_padded - real_batch_size
-        seq_group_metadata_list = seq_group_metadata_list.copy()
-        if batch_size_padding > 0:
-            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                0, 0, is_prompt)
-            seq_group_metadata_list.extend(dummy_seq_group_metadata
-                                           for _ in range(batch_size_padding))
+        seq_group_metadata_list, real_batch_size, batch_size_padded = (
+            self._add_dummy_seq(seq_group_metadata_list, is_prompt))
 
         prefill_reqs = []
         decode_reqs = []
@@ -2066,19 +2077,6 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
-    def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
-        real_batch_size = len(seq_group_metadata_list)
-        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-            real_batch_size, is_prompt)
-        batch_size_padding = batch_size_padded - real_batch_size
-        seq_group_metadata_list = seq_group_metadata_list.copy()
-        if batch_size_padding > 0:
-            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                0, 0, is_prompt)
-            seq_group_metadata_list.extend(dummy_seq_group_metadata
-                                           for _ in range(batch_size_padding))
-        return seq_group_metadata_list
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -2255,7 +2253,7 @@ def try_revert_dummy_output_tokens():
                         for j, data in seq_group_metadata.seq_data.items():
                             cache_orig_output_tokens_len[seq_idx][j] = \
                                 len(data.output_token_ids)
-                    seq_group_metadata_list = self.add_dummy_seq(
+                    seq_group_metadata_list, _, _ = self._add_dummy_seq(
                         seq_group_metadata_list, is_prompt=False)
                     for seq_group_metadata in seq_group_metadata_list:
                         for data in seq_group_metadata.seq_data.values():

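One way to read the rule introduced above: padding never adds a sampling mode that the real sequences do not already use, which is presumably what keeps the sampler from recompiling on incomplete batches, per the commit message. The property check below is an illustration written against the same temperature == 0.0 convention the diff itself uses (see has_greedy_samples); the helper names are not vLLM API.

from typing import List, Set


def modes(temps: List[float]) -> Set[str]:
    # The diff treats temperature == 0.0 as greedy sampling; anything else
    # is treated as random sampling here.
    return {"greedy" if t == 0.0 else "random" for t in temps}


def dummy_temperature(real_temps: List[float]) -> float:
    # Same rule as _add_dummy_seq: greedy padding if any real sequence is greedy.
    return 0.0 if any(t == 0.0 for t in real_temps) else 1.0


if __name__ == "__main__":
    for real_temps in ([0.0, 0.0], [0.8, 1.0], [0.0, 0.8]):
        padded = real_temps + [dummy_temperature(real_temps)]  # pad by one dummy
        # Padding never introduces a sampling mode the real sequences lack.
        assert modes(padded) == modes(real_temps)
        print(real_temps, "->", sorted(modes(padded)))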