From c245ef0b5bdddb6f6e595020a9139a5a8c7bad34 Mon Sep 17 00:00:00 2001
From: Artur Fierka
Date: Mon, 13 Jan 2025 09:28:38 +0100
Subject: [PATCH] Fix model OOM issue in llama-405 and mixtral - 2nd attempt
 (#644)

Another approach to fixing the OOM issue when loading the model. This
time, instead of changing model-specific code, I updated the model
weights iterator. Hopefully this fix will be easier to upstream.
---
 vllm/model_executor/model_loader/loader.py | 11 +++++++++++
 vllm/model_executor/models/llama.py        |  2 --
 vllm/model_executor/models/mixtral.py      |  3 ---
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index b9866738d03e9..4c82568d5a213 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -321,6 +321,17 @@ def _xla_weights_iterator(iterator: Generator):

             weights_iterator = _xla_weights_iterator(weights_iterator)

+        if current_platform.is_hpu():
+
+            import habana_frameworks.torch.core as htcore
+
+            def _hpu_weights_iterator(iterator: Generator):
+                for weights in iterator:
+                    yield weights
+                    htcore.mark_step()
+
+            weights_iterator = _hpu_weights_iterator(weights_iterator)
+
         # Apply the prefix.
         return ((source.prefix + name, tensor)
                 for (name, tensor) in weights_iterator)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 6461a80cef331..650f8483b76bb 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -435,8 +435,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        if is_hpu:
-            torch.hpu.synchronize()
         return loaded_params

     # If this function is called, it should always initialize KV cache scale
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 3688233c19a81..a5b364fe5ec85 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -44,7 +44,6 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -483,6 +482,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        if current_platform.is_hpu():
-            torch.hpu.synchronize()
         return loaded_params
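
For reviewers without Gaudi hardware, below is a minimal, self-contained
sketch of the iterator-wrapping pattern the loader change applies (it
mirrors the existing _xla_weights_iterator path). Two assumptions are
baked in: habana_frameworks only exists on HPU hosts, so mark_step() is
stubbed with a no-op to keep the sketch runnable anywhere, and the fake
weights generator at the bottom is purely illustrative.

    # Standalone sketch of the iterator-wrapping pattern from the patch.
    from typing import Generator, Iterable, Tuple

    import torch

    try:
        import habana_frameworks.torch.core as htcore
    except ImportError:
        class htcore:  # no-op stand-in for non-HPU machines (assumption)
            @staticmethod
            def mark_step() -> None:
                pass


    def _hpu_weights_iterator(
            iterator: Iterable[Tuple[str, torch.Tensor]]) -> Generator:
        # Yield each (name, tensor) pair unchanged, then mark a step so
        # the Habana bridge flushes the ops accumulated so far instead of
        # buffering the entire weight-loading sequence.
        for weights in iterator:
            yield weights
            htcore.mark_step()


    # Usage: wrap any weights iterator before the model consumes it.
    fake_weights = ((f"layer.{i}.weight", torch.zeros(4, 4))
                    for i in range(3))
    for name, tensor in _hpu_weights_iterator(fake_weights):
        print(name, tuple(tensor.shape))

The point of the design is that mark_step() runs after each weight is
yielded and consumed, so graph ops are presumably flushed per tensor
rather than accumulated across the whole checkpoint, which also makes
the per-model torch.hpu.synchronize() calls removed above unnecessary.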