[Model] support Qwen2 VL (#268)
* add qwen2-vl

* qwen2 vl (black isort)

* qwen2 vl black

* black

* without qwen vl utils and temp images

* black

* isort

* qwen2 vl batch generate

* remove unused import

* remove unreferenced
abzb1 authored and KairuiHu committed Sep 23, 2024
1 parent f92fd71 commit 0e6e09d
Showing 2 changed files with 244 additions and 0 deletions.
1 change: 1 addition & 0 deletions lmms_eval/models/__init__.py
@@ -35,6 +35,7 @@
"mplug_owl_video": "mplug_Owl",
"phi3v": "Phi3v",
"qwen_vl": "Qwen_VL",
"qwen2_vl": "Qwen2_VL",
"qwen_vl_api": "Qwen_VL_API",
"reka": "Reka",
"srt_api": "SRT_API",
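The new "qwen2_vl" entry follows the module's existing pattern of mapping a model name (as used on the command line) to the name of its wrapper class. As a minimal sketch of how such a name-to-class mapping can be resolved at load time; the registry dictionary name, the helper get_model_class, and the importlib lookup here are illustrative assumptions, not code from this diff:

import importlib

# Hypothetical lookup mirroring the mapping edited above; only the
# "qwen2_vl" -> "Qwen2_VL" pair comes from this commit.
MODEL_REGISTRY = {"qwen2_vl": "Qwen2_VL"}

def get_model_class(name: str):
    # Import lmms_eval.models.<name> and fetch the wrapper class by its registered name.
    module = importlib.import_module(f"lmms_eval.models.{name}")
    return getattr(module, MODEL_REGISTRY[name])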
243 changes: 243 additions & 0 deletions lmms_eval/models/qwen2_vl.py
@@ -0,0 +1,243 @@
from typing import List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model


@register_model("qwen2_vl")
class Qwen2_VL(lmms):
"""
Qwen2_VL Model
"https://github.com/QwenLM/Qwen2-VL"
"""

def __init__(
self,
pretrained: str = "Qwen/Qwen2-VL-7B-Instruct",
device: Optional[str] = "cuda",
device_map: Optional[str] = "cuda",
batch_size: Optional[Union[int, str]] = 1,
use_cache=True,
use_flash_attention_2: Optional[bool] = True,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

accelerator = Accelerator()
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
elif accelerator.num_processes == 1 and device_map == "auto":
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"

if use_flash_attention_2:
self._model = Qwen2VLForConditionalGeneration.from_pretrained(
pretrained,
torch_dtype="auto",
device_map=self.device_map,
attn_implementation="flash_attention_2",
).eval()
else:
self._model = Qwen2VLForConditionalGeneration.from_pretrained(pretrained, torch_dtype="auto", device_map=self.device_map).eval()
self.processor = AutoProcessor.from_pretrained(pretrained)
self._tokenizer = AutoTokenizer.from_pretrained(pretrained)

self._config = self.model.config
self.batch_size_per_gpu = int(batch_size)
self.use_cache = use_cache

if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self._rank = 0
            self._world_size = 1

@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config

@property
def tokenizer(self):
return self._tokenizer

@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model

@property
def eot_token_id(self):
return self.tokenizer.eos_token_id

@property
def max_length(self):
return self._max_length

@property
def batch_size(self):
return self.batch_size_per_gpu

@property
def device(self):
return self._device

@property
def rank(self):
return self._rank

@property
def world_size(self):
return self._world_size

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Loglikelihood is not implemented for Qwen2_VL")

def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list

def generate_until(self, requests: List[Instance]) -> List[str]:
res = []

def _collate(x):
            # The negative sign on len(toks) sorts descending, which has a few advantages:
            # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
            # - when walking through the batched list, the first element of each batch gives that batch's padded
            #   context length, which simplifies the batching logic and makes automatic adaptive batching much easier
            # - any OOMs will happen right away rather than near the end
toks = self.tokenizer.encode(x[0])
return -len(toks), x[0]

pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
visuals = self.flatten(visuals)

gen_kwargs = all_gen_kwargs[0]

            # Set the default stopping sequence; other generation defaults are filled in below
until = [self.tokenizer.decode(self.eot_token_id)]

# Update values from gen_kwargs if present
if "until" in gen_kwargs:
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")

if isinstance(contexts, tuple):
contexts = list(contexts)

for i in range(len(contexts)):
if "<image>" in contexts[i]:
contexts[i] = contexts[i].replace("<image>", "")

messages = []

if len(visuals) == 0:
for context in contexts:
message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "text", "text": context}]}]
messages.append(message)
else:
for _, context in zip(visuals, contexts):
message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": context}]}]
messages.append(message)

texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages]
inputs = self.processor(text=texts, images=[visuals], padding=True, return_tensors="pt")

if self.device_map == "auto":
inputs = inputs.to("cuda")
else:
inputs = inputs.to(self.device)

# preconfigure gen_kwargs with defaults
if "image_sizes" not in gen_kwargs:
try:
gen_kwargs["image_sizes"] = [visuals[0].size]
                except Exception:
gen_kwargs["image_sizes"] = None
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 128
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1

pad_token_id = self.tokenizer.pad_token_id

cont = self.model.generate(
**inputs,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=pad_token_id,
do_sample=True if gen_kwargs["temperature"] > 0 else False,
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=self.use_cache,
# kwargs=gen_kwargs
)

generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for i, ans in enumerate(answers):
for term in until:
if len(term) > 0:
ans = ans.split(term)[0]
answers[i] = ans

for ans, context in zip(answers, contexts):
res.append(ans)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)

pbar.close()
return res
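
For reference, below is a minimal standalone sketch (not part of this commit) of the single-image generation path the wrapper implements, using the same transformers calls as the code above; the prompt, the blank test image, and the generation settings are illustrative assumptions.

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

pretrained = "Qwen/Qwen2-VL-7B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(pretrained, torch_dtype="auto", device_map="cuda").eval()
processor = AutoProcessor.from_pretrained(pretrained)

# Same message structure generate_until builds: system prompt, one image placeholder, then the text context.
image = Image.new("RGB", (448, 448), color="white")  # stand-in image; any PIL image works here
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]},
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=128)

# Strip the prompt tokens before decoding, as the wrapper does with generated_ids_trimmed.
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

With the registration above, the model should also be selectable from the lmms-eval command line as qwen2_vl, with pretrained=Qwen/Qwen2-VL-7B-Instruct passed in the model arguments (assuming the harness's usual --model and --model_args flags).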
