update wrap_unwrap context for zero3-model inference
choiszt committed Sep 6, 2024
1 parent ebdfe16 commit 2450cde
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lmms_eval/models/llava_onevision.py
@@ -24,7 +24,7 @@
 from lmms_eval.api.registry import register_model
 from lmms_eval.models.model_utils.load_video import read_video_pyav
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-
+from trl.models.utils import unwrap_model_for_generation
 # Suppress warnings
 warnings.filterwarnings("ignore")
 
@@ -565,12 +565,20 @@ def _collate(x):
 # TODO: attention to this major generation step...
 if "image_aspect_ratio" in gen_kwargs.keys():
     gen_kwargs.pop("image_aspect_ratio")
 
+
 try:
     with torch.inference_mode():
-        cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+        with unwrap_model_for_generation(self._model,self.accelerator) as unwrapped_model:
+            cont = unwrapped_model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+
+
+    # cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+
     # cont = self.model.generate(qwen_input_ids, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+
     text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
+    self.accelerator.wait_for_everyone()
 except Exception as e:
     raise e
 
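For context, the pattern the commit adopts comes from trl's unwrap_model_for_generation helper: when the model has been prepared with DeepSpeed ZeRO-3, its parameters are sharded across ranks, so generation is run on a temporarily unwrapped (and, under ZeRO-3, gathered) module instead of calling self.model.generate directly. Below is a minimal standalone sketch of that usage pattern, not code from this repository; it assumes trl, accelerate, and transformers are installed, and the tiny checkpoint name is only a placeholder.

# Minimal sketch of the unwrap-for-generation pattern (not the repo's code).
# Under a DeepSpeed ZeRO-3 config, unwrap_model_for_generation gathers the
# sharded weights for the duration of the context; without DeepSpeed it
# simply yields the unwrapped model, so this also runs on a single GPU/CPU.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.models.utils import unwrap_model_for_generation

accelerator = Accelerator()
checkpoint = "sshleifer/tiny-gpt2"  # placeholder checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained(checkpoint))

inputs = tokenizer("Hello world", return_tensors="pt").to(accelerator.device)

with torch.inference_mode():
    # Generate on the unwrapped module, mirroring the change in the diff above.
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        out = unwrapped_model.generate(**inputs, max_new_tokens=8)

accelerator.wait_for_everyone()  # keep all ranks in sync after generation
print(tokenizer.batch_decode(out, skip_special_tokens=True))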
