[Model] Support Ross #718

Merged · 3 commits · Jan 16, 2025
6 changes: 5 additions & 1 deletion vlmeval/config.py
@@ -412,6 +412,10 @@
     'valley_eagle': partial(ValleyEagleChat, model_path='bytedance-research/Valley-Eagle-7B'),
 }

+ross_series = {
+    'ross-qwen2-7b': partial(Ross, model_path='HaochenWang/ross-qwen2-7b'),
+}
+
 supported_VLM = {}

 model_groups = [
@@ -423,7 +427,7 @@
     mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
     slime_series, eagle_series, moondream_series, llama_series, molmo_series,
     kosmos_series, points_series, nvlm_series, vintern_series, h2ovl_series, aria_series,
-    smolvlm_series, sail_series, valley_series, vita_series
+    smolvlm_series, sail_series, valley_series, vita_series, ross_series
 ]

 for grp in model_groups:
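For reference, a minimal sketch of how this registration surfaces the new model (names as in the diff; nothing beyond the standard VLMEvalKit layout is assumed):

from vlmeval.config import supported_VLM

# The loop shown in the diff merges every *_series dict into supported_VLM,
# so the new entry is a functools.partial over the Ross class:
builder = supported_VLM['ross-qwen2-7b']
model = builder()  # instantiates Ross with model_path='HaochenWang/ross-qwen2-7b'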
1 change: 1 addition & 0 deletions vlmeval/vlm/__init__.py
@@ -63,3 +63,4 @@
 from .smolvlm import SmolVLM
 from .sail_vl import SailVL
 from .valley import ValleyEagleChat
+from .ross import Ross
160 changes: 160 additions & 0 deletions vlmeval/vlm/ross.py
@@ -0,0 +1,160 @@
import torch
from PIL import Image
from abc import abstractproperty
import sys
import copy
import os.path as osp
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class Ross(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = True

    def __init__(self,
                 model_path='HaochenWang/ross-qwen2-7b',
                 **kwargs):
        from ross.model.builder import load_pretrained_model
        from ross.mm_utils import get_model_name_from_path

        assert osp.exists(model_path) or splitlen(model_path) == 2

        model_name = get_model_name_from_path(model_path)

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=model_name,
            device='cpu',
            device_map='cpu',
            torch_dtype=torch.float16,
        )
        self.model.get_vision_tower().load_model()
        self.model.eval()
        self.model.cuda()

        # Pick the conversation template from the checkpoint name. The default
        # path is lowercase ('ross-qwen2-7b'), so compare case-insensitively.
        if 'qwen2' in model_path.lower():
            self.conv_mode = 'v1_qwen2'
        elif 'llama3' in model_path.lower():
            self.conv_mode = 'llama3'
        else:
            self.conv_mode = 'llava_v1'

        kwargs_default = dict(
            do_sample=False,
            temperature=0,
            max_new_tokens=512,
            top_p=None,
            num_beams=1,
            use_cache=True,
        )
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def use_custom_prompt(self, dataset: str) -> bool:
        dataset_type = DATASET_TYPE(dataset, default=None)
        if dataset_type == 'MCQ':
            return True
        if dataset_type == 'Y/N' and dataset in {'HallusionBench'}:
            return True
        return False

    def build_prompt(self, line, dataset: str) -> list[dict[str, str]]:
        dataset_type = DATASET_TYPE(dataset, default=None)
        if dataset_type == 'MCQ':
            return self._build_mcq_prompt(line, dataset)
        if dataset_type == 'Y/N':
            return self._build_yorn_prompt(line, dataset)
        raise ValueError(f'Unsupported dataset: {dataset}')

    def _build_yorn_prompt(self, line, dataset: str) -> list[dict[str, str]]:
        YORN_PROMPT = "\nAnswer the question using a single word or phrase."

        tgt_path = self.dump_image(line, dataset)
        question = line['question']
        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=question))
        assert msgs[-1]['type'] == 'text'
        msgs[-1]['value'] += YORN_PROMPT
        return msgs

    def _build_mcq_prompt(self, line, dataset: str) -> list[dict[str, str]]:
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += (
                '\n请直接回答选项字母。' if cn_string(prompt) else
                "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        message = [dict(type='image', value=s) for s in tgt_path]
        message.append(dict(type='text', value=prompt))
        return message

    def generate_inner(self, message, dataset=None):
        from ross.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
        from ross.constants import (
            IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
        from ross.conversation import conv_templates, SeparatorStyle

        # Support interleaved text and images
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], 'PLACEHOLDER')
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        content, images = '', []
        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
            elif msg['type'] == 'image':
                if self.model.config.mm_use_im_start_end:
                    content += DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n'
                else:
                    content += DEFAULT_IMAGE_TOKEN + '\n'
                images.append(msg['value'])

        images = [Image.open(s).convert('RGB') for s in images]
        # abstractproperty() is used only as a cheap attribute container here
        # (a LLaVA-codebase idiom), not for its abc semantics.
        args = abstractproperty()
        args.image_aspect_ratio = 'pad'
        image_tensor = process_images(images, self.image_processor, args).to('cuda', dtype=torch.float16)
        prompt = prompt.replace('PLACEHOLDER', content)

        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        # For Qwen2-style templates, also stop on "<|im_end|>".
        keywords = [stop_str, "<|im_end|>"] if stop_str == "<|im_start|>" else [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids, images=image_tensor, stopping_criteria=[stopping_criteria], **self.kwargs)

        output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return output
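A hedged end-to-end usage sketch (assuming the ross package is installed, per INSTALL_REQ, and a CUDA device is available; demo.jpg is a hypothetical local image):

from vlmeval.config import supported_VLM

model = supported_VLM['ross-qwen2-7b']()
# BaseModel.generate routes this interleaved message list into generate_inner above.
response = model.generate([
    dict(type='image', value='demo.jpg'),
    dict(type='text', value='Describe this image.'),
])
print(response)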