Skip to content

Commit

Permalink
Mypy/formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mfylcek committed Jan 13, 2025
1 parent 4530295 commit 99ca01c
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,27 @@ def load_model(self) -> None:
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)

def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
real_batch_size = len(seq_group_metadata_list)
batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
real_batch_size, is_prompt)
batch_size_padding = batch_size_padded - real_batch_size

seq_group_metadata_list = seq_group_metadata_list.copy()

if batch_size_padding > 0:
sampling_temperatures = [
seq_group_metadata.sampling_params.temperature
for seq_group_metadata in seq_group_metadata_list
]
temperature = 0 if 0 in sampling_temperatures else 1.0

dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
0, 0, is_prompt, temperature=temperature)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
return seq_group_metadata_list, real_batch_size, batch_size_padded

def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
Expand Down Expand Up @@ -1224,8 +1245,8 @@ def prepare_input_tensors(
base_event_name = 'prompt' if is_prompt else 'decode'
self.profiler.start('internal', base_event_name)

seq_group_metadata_list, real_batch_size, batch_size_padded = self.add_dummy_seq(
seq_group_metadata_list, is_prompt)
seq_group_metadata_list, real_batch_size, batch_size_padded = (
self._add_dummy_seq(seq_group_metadata_list, is_prompt))

prefill_reqs = []
decode_reqs = []
Expand Down Expand Up @@ -2032,27 +2053,6 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

return lora_mask, lora_logits_mask

def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
    """Pad the batch with dummy sequence groups up to the bucketed size.

    NOTE(review): this appears to be the pre-rename copy of
    ``_add_dummy_seq`` shown elsewhere in this diff; the bodies are
    identical.

    Returns a tuple of (padded_metadata_list, real_batch_size,
    batch_size_padded).
    """
    real_batch_size = len(seq_group_metadata_list)
    # Ask the bucketing context which padded batch size this real batch
    # maps to (prompt and decode phases may use different buckets).
    batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
        real_batch_size, is_prompt)
    batch_size_padding = batch_size_padded - real_batch_size

    # Shallow copy so the padding below never mutates the caller's list.
    seq_group_metadata_list = seq_group_metadata_list.copy()

    if batch_size_padding > 0:
        sampling_temperatures = [
            seq_group_metadata.sampling_params.temperature
            for seq_group_metadata in seq_group_metadata_list
        ]
        # If any real request samples greedily (temperature == 0), the
        # dummy entries use temperature 0 as well; otherwise 1.0.
        temperature = 0 if 0 in sampling_temperatures else 1.0

        # A single dummy metadata object is reused for every padding slot.
        dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
            0, 0, is_prompt, temperature=temperature)
        seq_group_metadata_list.extend(dummy_seq_group_metadata
                                       for _ in range(batch_size_padding))
    return seq_group_metadata_list, real_batch_size, batch_size_padded

@torch.inference_mode()
def execute_model(
self,
Expand Down Expand Up @@ -2228,7 +2228,7 @@ def try_revert_dummy_output_tokens():
for j, data in seq_group_metadata.seq_data.items():
cache_orig_output_tokens_len[seq_idx][j] = \
len(data.output_token_ids)
seq_group_metadata_list, _, _ = self.add_dummy_seq(
seq_group_metadata_list, _, _ = self._add_dummy_seq(
seq_group_metadata_list, is_prompt=False)
for seq_group_metadata in seq_group_metadata_list:
for data in seq_group_metadata.seq_data.values():
Expand Down

0 comments on commit 99ca01c

Please sign in to comment.