Random sampler warmup (#506)
mfylcek authored Nov 20, 2024
2 parents 8c3f56a + e24a5af commit 6338608
Showing 1 changed file with 17 additions and 8 deletions.

vllm/worker/hpu_model_runner.py
@@ -1425,8 +1425,9 @@ def create_dummy_seq_group_metadata(self,
                                         group_id,
                                         seq_len,
                                         is_prompt,
-                                        lora_request=None):
-        sampling_params = SamplingParams(temperature=0)
+                                        lora_request=None,
+                                        temperature=0):
+        sampling_params = SamplingParams(temperature=temperature)
         num_blocks = math.ceil(seq_len / self.block_size)
         seq_len = max(seq_len, 1)
         if is_prompt:
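
The temperature parameter added here matters because SamplingParams uses temperature to select the sampling path in vLLM: temperature=0 means greedy sampling, while any non-zero value routes the dummy sequences through the random sampler. A minimal sketch of that distinction (illustrative only, not part of this commit):

    from vllm import SamplingParams

    # temperature=0 -> greedy path: warmup only compiles argmax sampling.
    greedy_params = SamplingParams(temperature=0)

    # temperature=1.0 -> random path: softmax plus a multinomial draw,
    # a different set of ops that needs its own warmup on HPU.
    random_params = SamplingParams(temperature=1.0)
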
@@ -1466,7 +1467,8 @@ def warmup_scenario(self,
                         is_prompt,
                         kv_caches,
                         is_pt_profiler_run=False,
-                        is_lora_profile_run=False) -> None:
+                        is_lora_profile_run=False,
+                        temperature=0) -> None:
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         scenario_name = ("warmup_"
                          f"{'prompt' if is_prompt else 'decode'}_"
@@ -1505,8 +1507,8 @@ def warmup_scenario(self,
                     seq_len,
                     is_prompt,
                     lora_request=dummy_lora_requests_per_seq[i]
-                    if dummy_lora_requests_per_seq else None)
-                for i in range(batch_size)
+                    if dummy_lora_requests_per_seq else None,
+                    temperature=temperature) for i in range(batch_size)
             ]
         else:
             # FIXME: seq_len is actually number of blocks
@@ -1518,8 +1520,8 @@ def warmup_scenario(self,
                     b * self.block_size - 1,
                     is_prompt,
                     lora_request=dummy_lora_requests_per_seq[i]
-                    if dummy_lora_requests_per_seq else None)
-                for i, b in enumerate(blocks)
+                    if dummy_lora_requests_per_seq else None,
+                    temperature=temperature) for i, b in enumerate(blocks)
             ]
         torch.hpu.synchronize()
         profiler = None
@@ -1629,6 +1631,7 @@ def warmup_graphs(self,
                 f'Unsupported graph allocation strategy: {strategy}')
         buckets = list(sorted(buckets, key=ordering))
         captured_all = True
+        warmed_random_sampler_bs: Set[int] = set()
         for idx, (batch_size, seq_len) in enumerate(buckets):
             # Graph memory usage is proportional to seq dimension in a batch
             batch_seq = batch_size * seq_len if is_prompt else batch_size
@@ -1642,7 +1645,13 @@ def warmup_graphs(self,
             self.graphed_buckets.add(graphed_bucket)
             self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
             with HabanaMemoryProfiler() as mem_prof:
-                self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
+                self.warmup_scenario(batch_size,
+                                     seq_len,
+                                     is_prompt,
+                                     kv_caches,
+                                     temperature=1.0 if batch_size
+                                     not in warmed_random_sampler_bs else 0)
+            warmed_random_sampler_bs.add(batch_size)
             used_mem = align_workers(mem_prof.consumed_device_memory,
                                      torch.distributed.ReduceOp.MAX)
             available_mem -= used_mem
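
The loop above flips temperature to 1.0 the first time each batch size appears, so the random sampler is compiled once per batch size during warmup; every later bucket with the same batch size runs greedy (temperature=0). A standalone sketch of the pattern, using a hypothetical pick_warmup_temperature helper that is not part of this commit:

    from typing import Set

    warmed_random_sampler_bs: Set[int] = set()

    def pick_warmup_temperature(batch_size: int) -> float:
        # First occurrence of a batch size: exercise the random sampler
        # so its graph is compiled now rather than at serving time.
        temperature = 1.0 if batch_size not in warmed_random_sampler_bs else 0
        warmed_random_sampler_bs.add(batch_size)
        return temperature

Tracking only the batch size (and not the sequence length) is consistent with the sampler seeing one row of logits per sequence, so its shapes depend on the batch dimension alone.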
