From 2450cde48b864388991b067c1bf7d6de233bcfea Mon Sep 17 00:00:00 2001
From: Choiszt
Date: Fri, 6 Sep 2024 16:44:52 +0800
Subject: [PATCH] update wrap_unwrap context for zero3-model inference

---
 lmms_eval/models/llava_onevision.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/lmms_eval/models/llava_onevision.py b/lmms_eval/models/llava_onevision.py
index 907009b9..1a5b0882 100644
--- a/lmms_eval/models/llava_onevision.py
+++ b/lmms_eval/models/llava_onevision.py
@@ -24,7 +24,7 @@
 from lmms_eval.api.registry import register_model
 from lmms_eval.models.model_utils.load_video import read_video_pyav
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-
+from trl.models.utils import unwrap_model_for_generation
 
 # Suppress warnings
 warnings.filterwarnings("ignore")
@@ -565,12 +565,20 @@ def _collate(x):
             # TODO: attention to this major generation step...
             if "image_aspect_ratio" in gen_kwargs.keys():
                 gen_kwargs.pop("image_aspect_ratio")
+
+
             try:
                 with torch.inference_mode():
-                    cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+                    with unwrap_model_for_generation(self._model, self.accelerator) as unwrapped_model:
+                        cont = unwrapped_model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+
+
+                # cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+                # cont = self.model.generate(qwen_input_ids, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
                 text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
+                self.accelerator.wait_for_everyone()
             except Exception as e:
                 raise e
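
For context, the sketch below (not part of the patch) shows the wrap/unwrap pattern the diff adopts in isolation: TRL's unwrap_model_for_generation(model, accelerator) context manager temporarily gathers DeepSpeed ZeRO-3 parameter shards so generate() runs on the full module, then restores the sharded state on exit. It uses gpt2 as a stand-in for the LLaVA-OneVision model; the prompt and generation arguments are placeholders.

    # Minimal standalone sketch of the wrap/unwrap pattern for ZeRO-3 inference.
    import torch
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl.models.utils import unwrap_model_for_generation

    accelerator = Accelerator()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Under a DeepSpeed ZeRO-3 config, prepare() leaves the weights sharded across ranks.
    model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("gpt2"))

    inputs = tokenizer("Describe the image:", return_tensors="pt").to(accelerator.device)

    with torch.inference_mode():
        # Gathers the ZeRO-3 shards for the duration of the block, so generate()
        # sees the full, unwrapped model; shards are re-partitioned on exit.
        with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
            output_ids = unwrapped_model.generate(**inputs, max_new_tokens=32)

    # Keep ranks in sync after generation, mirroring the patch's wait_for_everyone() call.
    accelerator.wait_for_everyone()
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Without this wrapping, calling self.model.generate() directly on a ZeRO-3 sharded model can fail or hang because each rank only holds a slice of the parameters, which is what the patch works around.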