Setting enough cache_size_limit for torch.compile warmup (#238)
Fix the issue where warmup sometimes doesn't work because the default
cache_size_limit is only 8.

---------

Signed-off-by: zehao-intel <[email protected]>
Co-authored-by: Andrzej Kotłowski <[email protected]>
zehao-intel and anko-intel authored Sep 25, 2024
1 parent 8c6dcae commit cef2f54
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions vllm/worker/habana_model_runner.py
@@ -1553,6 +1553,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                     len(self.decode_buckets),
                     list(sorted(self.decode_buckets)))
 
+        if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
+            cache_size_limit = len(self.prompt_buckets) + len(
+                self.decode_buckets) + 1
+            torch._dynamo.config.cache_size_limit = max(
+                cache_size_limit, torch._dynamo.config.cache_size_limit)
+            # Multiply by 8 to follow the original default ratio between
+            # the cache_size_limit and accumulated_cache_size_limit
+            torch._dynamo.config.accumulated_cache_size_limit = max(
+                cache_size_limit * 8,
+                torch._dynamo.config.accumulated_cache_size_limit)
+
         start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
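For context, here is a minimal standalone sketch of the Dynamo behaviour this patch relies on. It is not part of the commit and does not use HPU or vLLM code: the bucket shapes and toy_forward function are made up, and it simply shows that each distinct input shape compiled with torch.compile(dynamic=False) occupies one cache entry, so warming up more shapes than cache_size_limit allows would stop further compilation unless the limits are raised first, which is exactly what the added code does before warmup.

# Hedged sketch, not HPU/vLLM code: hypothetical bucket shapes and a toy
# compiled function illustrating the cache limits the patch adjusts.
import torch

# Hypothetical warmup buckets (stand-ins for prompt/decode buckets).
bucket_shapes = [(bs, seq) for bs in (1, 2, 4, 8, 16) for seq in (128, 256)]

# Mirror the commit's logic: make room for every bucket before warmup,
# keeping the default 8x ratio for the accumulated limit.
cache_size_limit = len(bucket_shapes) + 1
torch._dynamo.config.cache_size_limit = max(
    cache_size_limit, torch._dynamo.config.cache_size_limit)
torch._dynamo.config.accumulated_cache_size_limit = max(
    cache_size_limit * 8,
    torch._dynamo.config.accumulated_cache_size_limit)

@torch.compile(dynamic=False)  # one compiled graph per distinct input shape
def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x) * 2

# Warmup: ten shapes exceed the default cache_size_limit of 8, so without
# the bump above some of these shapes would silently fall back to eager.
for shape in bucket_shapes:
    toy_forward(torch.randn(shape))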
