From 0e6e09d7158d2ff4775fa9429610121d87eea5cd Mon Sep 17 00:00:00 2001
From: Hongseok Oh <97136787+abzb1@users.noreply.github.com>
Date: Sun, 22 Sep 2024 18:49:54 +0900
Subject: [PATCH] [Model] support Qwen2 VL (#268)

* add qwen2-vl
* qwen2 vl (black isort)
* qwen2 vl black
* black
* without qwen vl utils and temp images
* black
* isort
* qwen2 vl batch generate
* remove unused import
* remove unreferenced
---
 lmms_eval/models/__init__.py |   1 +
 lmms_eval/models/qwen2_vl.py | 243 +++++++++++++++++++++++++++++++++++
 2 files changed, 244 insertions(+)
 create mode 100755 lmms_eval/models/qwen2_vl.py

diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py
index bfbb2260..bd93825d 100755
--- a/lmms_eval/models/__init__.py
+++ b/lmms_eval/models/__init__.py
@@ -35,6 +35,7 @@
     "mplug_owl_video": "mplug_Owl",
     "phi3v": "Phi3v",
     "qwen_vl": "Qwen_VL",
+    "qwen2_vl": "Qwen2_VL",
     "qwen_vl_api": "Qwen_VL_API",
     "reka": "Reka",
     "srt_api": "SRT_API",
diff --git a/lmms_eval/models/qwen2_vl.py b/lmms_eval/models/qwen2_vl.py
new file mode 100755
index 00000000..40410d44
--- /dev/null
+++ b/lmms_eval/models/qwen2_vl.py
@@ -0,0 +1,243 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from accelerate import Accelerator, DistributedType
+from loguru import logger as eval_logger
+from tqdm import tqdm
+from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
+
+from lmms_eval import utils
+from lmms_eval.api.instance import Instance
+from lmms_eval.api.model import lmms
+from lmms_eval.api.registry import register_model
+
+
+@register_model("qwen2_vl")
+class Qwen2_VL(lmms):
+    """
+    Qwen2_VL Model
+    "https://github.com/QwenLM/Qwen2-VL"
+    """
+
+    def __init__(
+        self,
+        pretrained: str = "Qwen/Qwen2-VL-7B-Instruct",
+        device: Optional[str] = "cuda",
+        device_map: Optional[str] = "cuda",
+        batch_size: Optional[Union[int, str]] = 1,
+        use_cache=True,
+        use_flash_attention_2: Optional[bool] = True,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        # Do not use kwargs for now
+        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+
+        accelerator = Accelerator()
+        if accelerator.num_processes > 1:
+            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
+        elif accelerator.num_processes == 1 and device_map == "auto":
+            self._device = torch.device(device)
+            self.device_map = device_map
+        else:
+            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
+
+        if use_flash_attention_2:
+            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+                pretrained,
+                torch_dtype="auto",
+                device_map=self.device_map,
+                attn_implementation="flash_attention_2",
+            ).eval()
+        else:
+            self._model = Qwen2VLForConditionalGeneration.from_pretrained(pretrained, torch_dtype="auto", device_map=self.device_map).eval()
+        self.processor = AutoProcessor.from_pretrained(pretrained)
+        self._tokenizer = AutoTokenizer.from_pretrained(pretrained)
+
+        self._config = self.model.config
+        self.batch_size_per_gpu = int(batch_size)
+        self.use_cache = use_cache
+
+        if accelerator.num_processes > 1:
+            assert accelerator.distributed_type in [
+                DistributedType.FSDP,
+                DistributedType.MULTI_GPU,
+            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+            if accelerator.distributed_type == DistributedType.FSDP:
+                self._model = accelerator.prepare(self.model)
+            else:
+                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+            self.accelerator = accelerator
+            if self.accelerator.is_local_main_process:
+                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
+            self._rank = self.accelerator.local_process_index
+            self._world_size = self.accelerator.num_processes
+        else:
+            self._rank = 0
+            self._world_size = 1
+
+    @property
+    def config(self):
+        # return the associated transformers.AutoConfig for the given pretrained model.
+        return self._config
+
+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    @property
+    def model(self):
+        # returns the model, unwrapping it if using Accelerate
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self._model)
+        else:
+            return self._model
+
+    @property
+    def eot_token_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self._max_length
+
+    @property
+    def batch_size(self):
+        return self.batch_size_per_gpu
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def rank(self):
+        return self._rank
+
+    @property
+    def world_size(self):
+        return self._world_size
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        raise NotImplementedError("Loglikelihood is not implemented for Qwen2_VL")
+
+    def flatten(self, input):
+        new_list = []
+        for i in input:
+            for j in i:
+                new_list.append(j)
+        return new_list
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        res = []
+
+        def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+            toks = self.tokenizer.encode(x[0])
+            return -len(toks), x[0]
+
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
+        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+        for chunk in chunks:
+            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
+            task = task[0]
+            split = split[0]
+            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
+            visuals = self.flatten(visuals)
+
+            gen_kwargs = all_gen_kwargs[0]
+
+            # Set default values for until and max_new_tokens
+            until = [self.tokenizer.decode(self.eot_token_id)]
+
+            # Update values from gen_kwargs if present
+            if "until" in gen_kwargs:
+                until = gen_kwargs.pop("until")
+                if isinstance(until, str):
+                    until = [until]
+                elif not isinstance(until, list):
+                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
+
+            if isinstance(contexts, tuple):
+                contexts = list(contexts)
+
+            for i in range(len(contexts)):
+                if "<image>" in contexts[i]:
+                    contexts[i] = contexts[i].replace("<image>", "")
+
+            messages = []
+
+            if len(visuals) == 0:
+                for context in contexts:
+                    message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "text", "text": context}]}]
+                    messages.append(message)
+            else:
+                for _, context in zip(visuals, contexts):
+                    message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": context}]}]
+                    messages.append(message)
+
+            texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages]
+            inputs = self.processor(text=texts, images=[visuals], padding=True, return_tensors="pt")
+
+            if self.device_map == "auto":
+                inputs = inputs.to("cuda")
+            else:
+                inputs = inputs.to(self.device)
+
+            # preconfigure gen_kwargs with defaults
+            if "image_sizes" not in gen_kwargs:
+                try:
+                    gen_kwargs["image_sizes"] = [visuals[0].size]
+                except:
+                    gen_kwargs["image_sizes"] = None
+            if "max_new_tokens" not in gen_kwargs:
+                gen_kwargs["max_new_tokens"] = 128
+            if "temperature" not in gen_kwargs:
+                gen_kwargs["temperature"] = 0
+            if "top_p" not in gen_kwargs:
+                gen_kwargs["top_p"] = None
+            if "num_beams" not in gen_kwargs:
+                gen_kwargs["num_beams"] = 1
+
+            pad_token_id = self.tokenizer.pad_token_id
+
+            cont = self.model.generate(
+                **inputs,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=pad_token_id,
+                do_sample=True if gen_kwargs["temperature"] > 0 else False,
+                temperature=gen_kwargs["temperature"],
+                top_p=gen_kwargs["top_p"],
+                num_beams=gen_kwargs["num_beams"],
+                max_new_tokens=gen_kwargs["max_new_tokens"],
+                use_cache=self.use_cache,
+                # kwargs=gen_kwargs
+            )
+
+            generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
+            answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+            for i, ans in enumerate(answers):
+                for term in until:
+                    if len(term) > 0:
+                        ans = ans.split(term)[0]
+                answers[i] = ans
+
+            for ans, context in zip(answers, contexts):
+                res.append(ans)
+                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
+                pbar.update(1)
+        # reorder this group of results back to original unsorted form
+        res = re_ords.get_original(res)
+
+        pbar.close()
+        return res
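
Usage note: a minimal sketch of how the new wrapper can be exercised outside the harness. Only the `qwen2_vl` registry name and the constructor arguments come from the patch itself; everything else (running on a CUDA machine, having a `transformers` build that ships `Qwen2VLForConditionalGeneration`) is an assumption, not part of the diff.

    # Sketch: instantiate the wrapper directly, mirroring the defaults defined in __init__ above.
    from lmms_eval.models.qwen2_vl import Qwen2_VL

    model = Qwen2_VL(
        pretrained="Qwen/Qwen2-VL-7B-Instruct",  # default checkpoint from the patch
        device_map="cuda",
        batch_size=1,
        use_flash_attention_2=False,  # leave True only if flash-attn is installed
    )

Inside the harness, the same model is selected with `--model qwen2_vl` (plus `--model_args pretrained=...`), which resolves through the registry entry added to lmms_eval/models/__init__.py.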