Fix model OOM issue in llama-405 and mixtral - 2nd attempt (#644)
Another approach to fixing the OOM issue when loading models. This time, instead of changing specific models' code, I updated the model weights iterator. Hopefully this fix will be easier to upstream.
afierka-intel authored Jan 13, 2025
1 parent c83289e commit c245ef0
Showing 3 changed files with 11 additions and 5 deletions.
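
For illustration, the fix relies on generator resumption order: code placed after a yield runs only when the consumer requests the next item, so a flush call there fires after each weight has already been handed to the weight loader. A minimal, runnable sketch of the same wrap-and-flush pattern (the _flushing_iterator name and the print-based flush stand-in are illustrative, not part of the commit):

from typing import Callable, Generator, Iterable, Tuple

def _flushing_iterator(iterator: Iterable[Tuple[str, str]],
                       flush: Callable[[], None]) -> Generator:
    # Yield items unchanged, calling flush() after each one is consumed.
    for item in iterator:
        yield item
        # Execution resumes here only when the consumer asks for the
        # next item, i.e. after it has finished processing `item`.
        flush()

weights = [("layer.0.weight", "<tensor>"), ("layer.1.weight", "<tensor>")]
for name, _ in _flushing_iterator(weights, lambda: print("  flush")):
    print("loading", name)
# Output:
# loading layer.0.weight
#   flush
# loading layer.1.weight
#   flush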
11 changes: 11 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -321,6 +321,17 @@ def _xla_weights_iterator(iterator: Generator):
 
         weights_iterator = _xla_weights_iterator(weights_iterator)
 
+        if current_platform.is_hpu():
+
+            import habana_frameworks.torch.core as htcore
+
+            def _hpu_weights_iterator(iterator: Generator):
+                for weights in iterator:
+                    yield weights
+                    htcore.mark_step()
+
+            weights_iterator = _hpu_weights_iterator(weights_iterator)
+
         # Apply the prefix.
         return ((source.prefix + name, tensor)
                 for (name, tensor) in weights_iterator)
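A note on why the iterator-level mark_step helps (editorial gloss, not from the commit): on Gaudi, the PyTorch bridge defaults to lazy mode, queuing ops until htcore.mark_step() cuts the accumulated graph and executes it. Without a periodic cut, the whole checkpoint-loading sequence can pile up as one pending graph, a plausible source of the OOM; flushing once per yielded weight keeps the pending work bounded. The wrapper also mirrors the _xla_weights_iterator pattern directly above it, which is what lets the per-model torch.hpu.synchronize() calls below be deleted.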
2 changes: 0 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -435,8 +435,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        if is_hpu:
-            torch.hpu.synchronize()
         return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
3 changes: 0 additions & 3 deletions vllm/model_executor/models/mixtral.py
@@ -44,7 +44,6 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -483,6 +482,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        if current_platform.is_hpu():
-            torch.hpu.synchronize()
         return loaded_params
