update features for lmms-eval mid checkpoint and one customized version of llava_onevision
choiszt committed Sep 4, 2024
1 parent 0c14015 commit ebdfe16
Showing 4 changed files with 41 additions and 22 deletions.
4 changes: 2 additions & 2 deletions lmms_eval/api/model.py
@@ -74,7 +74,7 @@ def generate_until(self, requests) -> List[str]:
pass

@classmethod
- def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Optional[dict] = None) -> T:
+ def create_from_arg_string(cls: Type[T], tuned_model, tuned_model_tokenizer, arg_string: str, additional_config: Optional[dict] = None) -> T:
"""
Creates an instance of the LMM class using the given argument string and additional config.
@@ -88,7 +88,7 @@ def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Opt
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
- return cls(**args, **args2)
+ return cls(tuned_model=tuned_model, tuned_model_tokenizer=tuned_model_tokenizer, **args, **args2)

@property
def rank(self):
5 changes: 4 additions & 1 deletion lmms_eval/evaluator.py
@@ -45,6 +45,7 @@
@positional_deprecated
def simple_evaluate(
tuned_model,
+ tuned_model_tokenizer,
model,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
@@ -163,6 +164,8 @@ def simple_evaluate(

ModelClass = get_model(model)
lm = ModelClass.create_from_arg_string(
+ tuned_model,
+ tuned_model_tokenizer,
model_args,
{
"batch_size": batch_size,
@@ -251,7 +254,7 @@ def _adjust_config(task_dict):
verbosity=verbosity,
cli_args=cli_args,
)
- print(lm.rank)
+ # print(lm.rank)
if lm.rank == 0:
if isinstance(model, str):
model_name = model
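Taken together with the model.py change, this lets a training run hand its live checkpoint straight to the evaluator instead of reloading weights from disk. A minimal sketch of the intended mid-checkpoint call, assuming it runs from inside a training loop; model_in_training, its_tokenizer, the model_args string, and the task list are placeholders, not values from this commit:

from lmms_eval.evaluator import simple_evaluate

# Hypothetical objects produced by the training loop; they are forwarded as
# tuned_model and tuned_model_tokenizer down to the model wrapper.
results = simple_evaluate(
    model_in_training,
    its_tokenizer,
    model="llava_onevision",
    model_args="pretrained=lmms-lab/llava-onevision-qwen2-0.5b-ov",  # placeholder checkpoint id
    tasks=["mme"],
    batch_size=1,
)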
53 changes: 35 additions & 18 deletions lmms_eval/models/llava_onevision.py
@@ -23,6 +23,7 @@
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.load_video import read_video_pyav
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

# Suppress warnings
warnings.filterwarnings("ignore")
@@ -85,6 +86,8 @@ def __init__(
mm_spatial_pool_mode: Optional[str] = "bilinear",
token_strategy: Optional[str] = "single", # could be "single" or "multiple", "multiple" denotes adding multiple <image> tokens for each frame
video_decode_backend: str = "decord",
+ tuned_model: Optional[Any] = None,
+ tuned_model_tokenizer: Optional[Any] = None,
**kwargs,
) -> None:
super().__init__()
@@ -139,13 +142,29 @@ def __init__(
overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor

llava_model_args["overwrite_config"] = overwrite_config
- try:
-     # Try to load the model with the multimodal argument
-     self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
- except TypeError:
-     # for older versions of LLaVA that don't have multimodal argument
-     llava_model_args.pop("multimodal", None)
-     self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)

+ if tuned_model:
+     self._model = tuned_model
+     self._image_processor = tuned_model.get_vision_tower().image_processor
+     self._tokenizer = tuned_model_tokenizer
+     if hasattr(tuned_model.config, "max_sequence_length"):
+         context_len = tuned_model.config.max_sequence_length
+     elif hasattr(tuned_model.config, "max_position_embeddings"):
+         context_len = tuned_model.config.max_position_embeddings
+     elif hasattr(tuned_model.config, "tokenizer_model_max_length"):
+         context_len = tuned_model.config.tokenizer_model_max_length
+     else:
+         context_len = 2048
+     self._max_length = context_len
+ else:
+     try:
+         # Try to load the model with the multimodal argument
+         self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
+     except TypeError:
+         # for older versions of LLaVA that don't have multimodal argument
+         llava_model_args.pop("multimodal", None)
+         self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)

self._config = self._model.config
self.model.eval()
@@ -448,12 +467,11 @@ def _collate(x):
placeholder_count = 1

elif type(visual[0]) == PIL.Image.Image: # For image, multi-image tasks
- # image_tensor = process_images(visual, self._image_processor, self._config)
- # if type(image_tensor) is list:
- #     image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
- # else:
- #     image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
- image_tensor = None
+ image_tensor = process_images(visual, self._image_processor, self._config)
+ if type(image_tensor) is list:
+     image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
+ else:
+     image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)

task_type = "image"
placeholder_count = len(visual) if isinstance(visual, list) else 1
@@ -548,12 +566,11 @@ def _collate(x):
if "image_aspect_ratio" in gen_kwargs.keys():
gen_kwargs.pop("image_aspect_ratio")
try:
- # with torch.inference_mode():
- #     cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
- #     # cont = self.model.generate(qwen_input_ids, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+ with torch.inference_mode():
+     cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
+     # cont = self.model.generate(qwen_input_ids, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)

- # text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
- text_outputs = "hi"
+ text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
except Exception as e:
raise e

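In the wrapper, a non-None tuned_model now short-circuits load_pretrained_model: the tokenizer, the vision tower's image processor, and the context length are taken from the passed-in objects, falling back through max_sequence_length, max_position_embeddings, and tokenizer_model_max_length before defaulting to 2048. A rough sketch of driving that path through the registry, assuming get_model is importable from lmms_eval.api.registry as in the evaluator; trained_model, trained_tokenizer, and the argument string below are hypothetical:

from lmms_eval.api.registry import get_model

ModelClass = get_model("llava_onevision")
# trained_model / trained_tokenizer stand in for a LLaVA-OneVision model and
# tokenizer already in memory; with tuned_model set, __init__ skips the
# load_pretrained_model call and reuses them directly.
lm = ModelClass.create_from_arg_string(
    trained_model,
    trained_tokenizer,
    "pretrained=lmms-lab/llava-onevision-qwen2-0.5b-ov",  # placeholder argument string
    {"batch_size": 1},
)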
1 change: 0 additions & 1 deletion lmms_eval/tasks/mme/utils.py
@@ -113,7 +113,6 @@ def mme_aggregate_results(results):
for category, question2scores in category2score.items():
total_score = 0
for question_id, scores in question2scores.items():
- print(score)
assert len(scores) == 2, "MME only supports pairwise evaluation"
acc = sum(scores) / len(scores) * 100.0
acc_plus = (sum(scores) == 2) * 100.0
