diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py
index 0c666145..9e11028e 100644
--- a/lmms_eval/__main__.py
+++ b/lmms_eval/__main__.py
@@ -211,7 +211,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         except Exception as e:
             traceback.print_exc()
             eval_logger.error(f"Error during evaluation: {e}")
-            traceback.print_exc()
+            # traceback.print_exc()
             results_list.append(None)
 
     for args, results in zip(args_list, results_list):
diff --git a/lmms_eval/api/instance.py b/lmms_eval/api/instance.py
index 41875358..67f00fde 100644
--- a/lmms_eval/api/instance.py
+++ b/lmms_eval/api/instance.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Literal, Tuple
+from typing import Literal, Tuple, Iterable, Callable
 
 
 @dataclass
diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py
index f77065e8..cecc7219 100644
--- a/lmms_eval/api/samplers.py
+++ b/lmms_eval/api/samplers.py
@@ -1,5 +1,274 @@
+import datasets
+from typing import Callable, Iterable, Optional, List
+from abc import ABC, abstractmethod
+
+
+class LazyLoadedImages(object):
+    def __init__(self, data_frame, index, doc_to_visual: Callable, image_tokens=""):
+        self.data_frame: datasets.Dataset = data_frame
+        self.index = index
+        self.image_lens = None
+        self.images = None
+        self.doc_to_visual = doc_to_visual
+        self.image_tokens = image_tokens
+
+    def get_images(self, lazy_save=False):
+        if self.images is not None:
+            return self.images
+        images = self.doc_to_visual(self.data_frame[self.index])
+        self.image_lens = len(images)
+        if lazy_save:
+            self.images = images
+        return images
+
+    def get_num_images(self, lazy_save=False):
+        if self.image_lens is None:
+            images = self.get_images(self.doc_to_visual)
+            if lazy_save:
+                self.images = images
+            self.image_lens = len(images)
+        return self.image_lens
+
+    def clear(self, clear_all=False):
+        self.images = None
+        if clear_all:
+            self.image_lens = None
+
+    def get_text(self, lazy: bool = True):
+        if lazy:
+            return self.image_tokens
+        else:
+            return " ".join([self.image_tokens] * self.get_num_images())
+
+
+class QAPairs(object):
+    def __init__(
+        self,
+        data_frame,
+        index,
+        *,
+        doc=None,
+        include_answer: bool = True,
+        doc_to_text: Callable,
+        doc_to_target: Optional[Callable] = None,
+        doc_to_choice: Optional[Callable] = None,
+        doc_to_visual: Optional[Callable] = None,
+        target_delimiter="\n",
+        delimiter="\n",
+        image_tokens="",
+        role_question="USER: ",
+        role_answer="ASSISTANT: ",
+        config=None,
+    ):
+        self.data_frame: datasets.Dataset = data_frame
+        self.index = index
+        self.target_delimiter = target_delimiter
+        self.doc_to_text = doc_to_text
+        self.doc_to_target = doc_to_target
+        self.doc_to_choice = doc_to_choice
+        self.delimiter = delimiter
+        if doc_to_visual:
+            self.vision = LazyLoadedImages(data_frame, index, doc_to_visual, image_tokens)
+        else:
+            self.vision = None
+        self.role_question = role_question
+        self.role_answer = role_answer
+        if doc is None:
+            doc = data_frame[index]
+        self.config = config
+        self.question = self._get_question(doc)
+        self.answer = self._get_target(doc) if include_answer else None
+
+    def _get_question(self, doc):
+        text = self.doc_to_text(doc)
+        return text if (self.doc_to_choice is None or isinstance(text, str)) else self.doc_to_choice(doc)[text]
+
+    def _get_target(self, doc):
+        return (
+            str(self.doc_to_target(doc)[0])
+            if type(self.doc_to_target(doc)) is list
+            else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
+        )
+
+    def get_text(self):
+        if self.answer is None:
+            return self.role_question + self.question + self.delimiter
+        else:
+            return self.role_question + self.question + self.delimiter + self.role_answer + self.answer
+
+    def __str__(self):
+        return self.get_text()
+
+    def get_visions(self):
+        if self.vision:
+            return self.vision.get_images()
+        else:
+            return []
+
+    def already_have_image_token(self, image_token):
+        return image_token in self.question or (self.answer and image_token in self.answer)
+
+    def num_images(self):
+        if self.vision:
+            return self.vision.get_num_images()
+        else:
+            return 0
+
+    def get_question_list(self, image_token, image_in_front=False):
+        questions = []
+        visions = self.get_visions()
+        if visions and self.already_have_image_token(image_token):
+            q_list = self.question.split(image_token)
+            for q, img in zip(q_list[:-1], visions):
+                if q != "":
+                    questions.append(q)
+                questions.append(img)
+            if q_list[-1] != "":
+                questions.append(q_list[-1])
+        else:
+            if image_in_front and visions:
+                questions.extend(visions)
+            questions.append(self.question)
+            if not image_in_front and visions:
+                questions.extend(visions)
+        return questions
+
+
+class Context(object):
+    def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n", description=None):
+        self.task = task
+        self.config = task._config
+
+        self.doc_to_visual = self.task.doc_to_visual
+        self.doc_to_text = self.task.doc_to_text
+        self.doc_to_target = self.task.doc_to_target
+        self.doc_to_choice = self.task.doc_to_choice
+
+        self.target_delimiter = target_delimiter
+        self.few_shot_delimiter = few_shot_delimiter
+
+        self.contexts: List[QAPairs] = []
+
+        self.description = description
+
+    def add_description(self, description):
+        self.description = description
+
+    def add_in_context_example(self, doc, data_frame, index):
+        # question = self.get_question(doc)
+        # if data_frame and index:
+        #     visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
+        # else:
+        #     visual = None
+        # target = self.doc_to_target(doc)
+        # if visual:
+        #     self.contexts.append(visual)
+        self.contexts.append(
+            QAPairs(
+                data_frame,
+                index,
+                doc=doc,
+                doc_to_text=self.doc_to_text,
+                doc_to_target=self.doc_to_target,
+                doc_to_choice=self.doc_to_choice,
+                doc_to_visual=self.doc_to_visual,
+                delimiter=self.target_delimiter,
+                config=self.config,
+            )
+        )
+        # self.contexts.append(self.few_shot_delimiter)
+
+    def add_question(self, doc, data_frame=None, index=None):
+        # question = self.get_question(doc)
+        # if data_frame and index:
+        #     visual = LazyLoadedImages(data_frame, index, self.doc_to_visual)
+        # else:
+        #     visual = None
+        # if visual:
+        #     self.contexts.append(visual)
+        self.contexts.append(
+            QAPairs(
+                data_frame,
+                index,
+                doc=doc,
+                doc_to_text=self.doc_to_text,
+                doc_to_target=self.doc_to_target,
+                doc_to_choice=self.doc_to_choice,
+                doc_to_visual=self.doc_to_visual,
+                delimiter=self.target_delimiter,
+                include_answer=False,
+                config=self.config,
+            )
+        )
+        # self.contexts.append(self.target_delimiter)
+
+    def already_have_image_token(self, image_token):
+        for context in self.contexts:
+            if context.already_have_image_token(image_token):
+                return True
+        return False
+
+    def get_text(self):
+        texts = []
+        for context in self.contexts:
+            texts.append(str(context))
+        return "".join(texts)
+
+    def get_visions(self):
+        visions = []
+        for context in self.contexts:
+            visions.extend(context.get_visions())
+        return visions
+
+    def extend(self, context):
+        if isinstance(context, list):
+            self.contexts.extend(context)
+        elif isinstance(context, Context):
+            self.contexts.extend(context.contexts)
+        else:
+            raise ValueError(f"Cannot extend context with object of type {type(context)}")
+
+    def append(self, context):
+        self.contexts.append(context)
+
+    def __str__(self):
+        return self.get_text()
+
+    def __lt__(self, other):
+        if not isinstance(other, Context):
+            return NotImplemented
+        return self.get_text() < other.get_text()
+
+
+class FewShotDataset(object):
+    def __init__(self, dataset=None, *, dataset_path: str = None, dataset_name: str = None, split: str = None, dataset_kwargs: dict = None, same_as_eval: bool = False):
+        if dataset is not None and (dataset_path is not None or dataset_name is not None or split is not None or dataset_kwargs is not None):
+            raise ValueError("Cannot provide both `dataset` and other dataset arguments!")
+        self.dataset_path = dataset_path
+        self.dataset_name = dataset_name
+        self.split = split
+        self.dataset = dataset
+        self.dataset_kwargs = dataset_kwargs if dataset_kwargs is not None else {}
+        self.same_as_eval = same_as_eval
+        self.fewshot_indices = None
+
+    def get_dataset(self) -> datasets.Dataset:
+        if self.dataset is None:
+            self.dataset = datasets.load_dataset(path=self.dataset_path, name=self.dataset_name, split=self.split, download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, **self.dataset_kwargs)
+            if self.fewshot_indices:
+                self.dataset = self.dataset.select(self.fewshot_indices)
+        return self.dataset
+
+    def sample(self, n, rnd):
+        indices = rnd.sample(range(len(self.get_dataset())), n)
+        return indices, self.get_dataset().select(indices)
+
+    def __getitem__(self, item):
+        return self.get_dataset()[item]
+
+
 class ContextSampler:
-    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
+    def __init__(self, docs: FewShotDataset, task, fewshot_indices=None, rnd=None) -> None:
         self.rnd = rnd
         assert self.rnd, "must pass rnd to FewShotSampler!"
 
@@ -13,37 +282,25 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
         self.doc_to_target = self.task.doc_to_target
         self.doc_to_choice = self.task.doc_to_choice
 
-        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
+        self.docs: FewShotDataset = docs  # HF dataset split, provided by task._fewshot_docs()
         if fewshot_indices:  # subset few-shot docs from
-            self.docs = self.docs.select(fewshot_indices)
+            self.docs.fewshot_indices = fewshot_indices
 
-    def get_context(self, doc, num_fewshot):
+    def get_context(self, doc, num_fewshot) -> Context:
         # draw an extra fewshot sample if using same split as evaluating on
-        n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
+        n_samples = num_fewshot + 1 if self.docs.same_as_eval else num_fewshot
 
         # draw `n_samples` docs from fewshot_docs
-        fewshotex = self.sample(n_samples)
+        indices, fewshotex = self.sample(n_samples)
 
         # get rid of the doc that's the one we're evaluating, if it's in the fewshot
         # TODO: should we just stop people from using fewshot from same split as evaluating?
-        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
-
-        labeled_examples = (
-            self.fewshot_delimiter.join(
-                [
-                    # TODO: is separating doc_to_text and doc_to_target by one space always desired?
-                    (self.doc_to_text(doc) if (self.config.doc_to_choice is None or type(self.doc_to_text(doc)) is str) else self.doc_to_choice(doc)[self.doc_to_text(doc)])
-                    + self.target_delimiter
-                    + (
-                        str(self.doc_to_target(doc)[0])
-                        if type(self.doc_to_target(doc)) is list
-                        else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
-                    )
-                    for doc in selected_docs
-                ]
-            )
-            + self.fewshot_delimiter
-        )
+        selected_docs = [(idx, x) for idx, x in zip(indices, fewshotex) if x != doc][:num_fewshot]
+
+        labeled_examples = Context(self.task, self.fewshot_delimiter, self.target_delimiter)
+
+        for idx, doc in selected_docs:
+            labeled_examples.add_in_context_example(doc, self.docs, idx)
 
         return labeled_examples
 
@@ -52,7 +309,7 @@ def sample(self, n):
         Draw `n` samples from our fewshot docs.
         This method should be overridden by subclasses.
         """
-        return self.rnd.sample(self.docs, n)
+        return self.docs.sample(n, self.rnd)
 
 
 class FirstNSampler(ContextSampler):
diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py
index 57305a53..96b3dc97 100644
--- a/lmms_eval/api/task.py
+++ b/lmms_eval/api/task.py
@@ -22,6 +22,7 @@ from lmms_eval import utils
 from lmms_eval.api import samplers
 from lmms_eval.api.instance import Instance
+from lmms_eval.api.samplers import FewShotDataset, Context
 from lmms_eval.filters import build_filter_ensemble
 
 from lmms_eval.api.registry import (
@@ -64,7 +65,7 @@ class TaskConfig(dict):
     training_split: str = None
     validation_split: str = None
     test_split: str = None
-    fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
+    # fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
 
     # formatting / prompting options.
# see docs/advanced_task_guide.md for more info process_docs: Callable = None @@ -305,13 +306,13 @@ def fewshot_docs(self): A iterable of any object, that doc_to_text can handle """ if self.has_training_docs(): - return self.training_docs() + return FewShotDataset(self.training_docs()) elif self.has_validation_docs(): - return self.validation_docs() + return FewShotDataset(self.validation_docs()) else: if self.config.num_fewshot is not None: eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.") - return self.test_docs() + return FewShotDataset(self.test_docs(), same_as_eval=True) def _process_doc(self, doc): """ @@ -550,8 +551,20 @@ def __init__(self, model_name) -> None: # TODO no super() call here self._filters.append(filter_pipeline) else: self._filters = [build_filter_ensemble("none", [["take_first", None]])] + ########################################## + # # TODO: for test, will delete later + # if self.config.task == "textvqa_test": + # pass + # else: + # pass + ########################################### if self.config.fewshot_config is not None: - self.sampler = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")(list(self.fewshot_docs()), self, rnd=random.Random(1234)) + try: + random_seed = self.config.fewshot_config.get("random_seed", 1234) + sampler_constructor = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default") + self.sampler = sampler_constructor(self.fewshot_docs(), self, rnd=random.Random(random_seed)) + except Exception as e: + eval_logger.error(f"Error in fewshot_config: {e}") if self.has_test_docs(): self.task_docs = self.test_docs() @@ -739,8 +752,19 @@ def test_docs(self) -> datasets.Dataset: return self.dataset[self.config.test_split] def fewshot_docs(self): - if self.config.fewshot_split is not None: - return self.dataset[self.config.fewshot_split] + + if "fewshot_dataset" in self.config.fewshot_config: + fewshot_dataset_config = self.config.fewshot_config["fewshot_dataset"] + if "dataset_path" not in fewshot_dataset_config: + fewshot_dataset_config["dataset_path"] = self.config.dataset_path + same_as_eval = self.config.dataset_path == fewshot_dataset_config["dataset_path"] and self.config.dataset_name == fewshot_dataset_config.get("dataset_name", None) and self.config.test_split == fewshot_dataset_config["split"] + return FewShotDataset( + dataset_path=fewshot_dataset_config["dataset_path"], + dataset_name=fewshot_dataset_config.get("dataset_name", None), + split=fewshot_dataset_config["split"], + dataset_kwargs=fewshot_dataset_config.get("dataset_kwargs", {}), + same_as_eval=same_as_eval, + ) else: if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): eval_logger.warning(f"Task '{self.config.task}': " "num_fewshot > 0 but fewshot_split is None. " "using preconfigured rule.") @@ -758,23 +782,14 @@ def fewshot_context(self, doc_id, num_fewshot, split): :returns: str The fewshot context. 
""" + from lmms_eval.api.samplers import Context + doc = self.dataset_no_image[split][doc_id] - if num_fewshot == 0: - # always prepend the (possibly empty) task description - labeled_examples = self.config.description - else: - labeled_examples = self.config.description + self.sampler.get_context(doc, num_fewshot) - example = self.doc_to_text(doc) - if type(example) == str: - return labeled_examples + example - elif type(example) == list: - return [labeled_examples + ex for ex in example] - elif type(example) == int: - if self.config.doc_to_choice is not None: - choices = self.doc_to_choice(doc) - return labeled_examples + choices[example] - else: - return labeled_examples + str(example) + labeled_examples = Context(self, self.config.fewshot_delimiter, self.config.target_delimiter, self.config.description) + if num_fewshot != 0: + labeled_examples.extend(self.sampler.get_context(doc, num_fewshot)) + labeled_examples.add_question(doc, self.test_docs(), doc_id) + return labeled_examples def apply_filters(self): if hasattr(self, "_filters"): @@ -917,7 +932,7 @@ def doc_to_choice(self, doc: Any) -> List[str]: else: raise TypeError - def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]: + def construct_requests(self, doc_id: int, ctx: Context, **kwargs) -> Union[List[Instance], Instance]: split = kwargs.get("split") kwargs.pop("split") if self.OUTPUT_TYPE == "loglikelihood": diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py index d2ec2025..9116f9f8 100644 --- a/lmms_eval/models/gpt4v.py +++ b/lmms_eval/models/gpt4v.py @@ -12,6 +12,7 @@ from lmms_eval.api.model import lmms from lmms_eval.api.registry import register_model from lmms_eval import utils +from lmms_eval.api.samplers import Context from PIL import Image @@ -65,33 +66,54 @@ def generate_until(self, requests) -> List[str]: for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: # encode, pad, and truncate contexts for this batch - visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] - visuals = self.flatten(visuals) - imgs = [] - for visual in visuals: - img = self.encode_image(visual) - imgs.append(img) - - payload = {"model": "gpt-4-vision-preview", "messages": []} - response_json = {"role": "user", "content": []} + # visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + # visuals = contexts.get_visions() + # visuals = self.flatten(visuals) + # imgs = [] + # for visual in visuals: + # img = self.encode_image(visual) + # imgs.append(img) + + # response_json = {"role": "user", "content": []} # When there is no image token in the context, append the image to the text - if self.image_token not in contexts: - payload["messages"].append(deepcopy(response_json)) - payload["messages"][0]["content"].append({"type": "text", "text": contexts}) - for img in imgs: - payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}) - else: - contexts = contexts.split(self.image_token) - for idx, img in enumerate(imgs): - payload["messages"].append(deepcopy(response_json)) - payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]}) - payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}) - - # If n image tokens are in the contexts - # contexts will be splitted into n+1 chunks - # Manually add it into the payload - payload["messages"].append(deepcopy(response_json)) - 
payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]}) + messages = [] + if contexts.description: + messages.append({"type": "system", "text": contexts.description}) + + for qa in contexts.contexts: + # content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}) + # payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]}) + content = [] + questions = qa.get_question_list(self.image_token) + for q in questions: + if isinstance(q, str): + content.append({"type": "text", "text": q}) + elif isinstance(q, Image.Image): + img = self.encode_image(q) + content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}) + messages.append({"role": "user", "content": content}) + if qa.answer: + messages.append({"role": "assistant", "content": [{"type": "text", "text": qa.answer}]}) + + payload = {"model": "gpt-4-vision-preview", "messages": messages} + + # if image_token_in_context: + # payload["messages"].append(deepcopy(response_json)) + # payload["messages"][0]["content"].append({"type": "text", "text": contexts}) + # for img in imgs: + # payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}) + # else: + # contexts = contexts.split(self.image_token) + # for idx, img in enumerate(imgs): + # payload["messages"].append(deepcopy(response_json)) + # payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]}) + # payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}) + + # If n image tokens are in the contexts + # contexts will be splitted into n+1 chunks + # Manually add it into the payload + # payload["messages"].append(deepcopy(response_json)) + # payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]}) if "max_new_tokens" not in gen_kwargs: gen_kwargs["max_new_tokens"] = 1024 @@ -109,6 +131,9 @@ def generate_until(self, requests) -> List[str]: try: response = url_requests.post(API_URL, headers=headers, json=payload, timeout=20) response_data = response.json() + + if "error" in response_data: + raise Exception(f"Error: {response_data['error']['message']}") content = response_data["choices"][0]["message"]["content"].strip() break # If successful, break out of the loop diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index a735f82c..0671dd15 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -73,6 +73,12 @@ def __init__( self.device_map = device_map self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, use_flash_attention_2=use_flash_attention_2) + if self._image_processor is None: + vision_tower = self._model.get_vision_tower() + if not vision_tower.is_loaded: + vision_tower.load_model() + vision_tower.to(device=device, dtype=torch.float16) + self._image_processor = vision_tower.image_processor self._config = self._model.config self.model.eval() self.model.tie_weights() @@ -82,7 +88,7 @@ def __init__( self.use_cache = use_cache self.truncate_context = truncate_context # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." 
- if accelerator.num_processes > 1 and device_map == "": + if accelerator.num_processes > 1 and device_map == "auto": assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works @@ -256,7 +262,7 @@ def _collate(x): # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end - toks = self.tok_encode(x[0]) + toks = self.tok_encode(str(x[0])) return -len(toks), x[0] # we group requests by their generation_kwargs, @@ -270,8 +276,20 @@ def _collate(x): contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) task = task[0] split = split[0] - visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] - visuals = self.flatten(visuals) + # batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N] + batched_visuals = [context.get_visions() for context in contexts] # [B, N] + # contexts_texts, batched_visuals = zip(*[context.get_text(image_tokens=DEFAULT_IMAGE_TOKEN, lazy=False) for context in contexts]) # [B, N] + flattened_visuals = self.flatten(batched_visuals) # [B*N] + # batched_visuals = context.get_visions() # [B, N] + # flattened_visuals = contexts[0].get_visions() # [B*N] + ############### for debugging ################### + # TODO: remove this block + # if len(visuals) > 1: + # for i in range(len(visuals)): + # path = f"./logs/llava/{i}.png" + # visuals[i].save(path) + # pass + ################################################# # we assume all gen kwargs in the batch are the same # this is safe to assume because the `grouper` object ensures it. gen_kwargs = all_gen_kwargs[0] @@ -292,8 +310,8 @@ def _collate(x): self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio") eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}") # encode, pad, and truncate contexts for this batch - if visuals: - image_tensor = process_images(visuals, self._image_processor, self._config) + if flattened_visuals: + image_tensor = process_images(flattened_visuals, self._image_processor, self._config) if type(image_tensor) is list: image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] else: @@ -305,41 +323,62 @@ def _collate(x): question_input = [] - for visual, context in zip(visuals, contexts): - if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context: - """ - Three senarios: - 1. No image, and there for, no image token should be added. - 2. image token is already specified in the context, so we don't need to add it. - 3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line. 
- """ - image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN] - image_tokens = " ".join(image_tokens) - question = image_tokens + "\n" + context - else: - question = context + for context in contexts: + # if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context: + # """ + # Three senarios: + # 1. No image, and there for, no image token should be added. + # 2. image token is already specified in the context, so we don't need to add it. + # 3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line. + # """ + # image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN] + # image_tokens = " ".join(image_tokens) + # if isinstance(context, list): + # context = "".join(context) + # question = image_tokens + "\n" + context + # else: + # question = context conv = conv_templates[self.conv_template].copy() - conv.append_message(conv.roles[0], question) - conv.append_message(conv.roles[1], None) + + from lmms_eval.api.samplers import LazyLoadedImages, QAPairs + + already_have_image_token = context.already_have_image_token(DEFAULT_IMAGE_TOKEN) + + for obj in context.contexts: + if already_have_image_token or obj.num_images() == 0: + question = obj.question + else: + question = " ".join(obj.num_images() * [DEFAULT_IMAGE_TOKEN]) + "\n" + obj.question + if context.description: + question = context.description + "\n" + question + answer = obj.answer + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], answer) + + # conv.append_message(conv.roles[0], question) + # conv.append_message(conv.roles[1], None) prompt_question = conv.get_prompt() question_input.append(prompt_question) # The above for loop has bugs. When there is no visuals, e.g. 
pure text, # there will be no for loop execute resulting in an empty question_input (because no visuals) # Scenario 1 won't even be execute - if len(visuals) == 0: - for context in contexts: - question = context - conv = conv_templates[self.conv_template].copy() - conv.append_message(conv.roles[0], question) - conv.append_message(conv.roles[1], None) - prompt_question = conv.get_prompt() - question_input.append(prompt_question) + # if len(flattened_visuals) == 0: + # for context in contexts: + # question = context + # conv = conv_templates[self.conv_template].copy() + # conv.append_message(conv.roles[0], question) + # conv.append_message(conv.roles[1], None) + # try: + # prompt_question = conv.get_prompt() + # except Exception as e: + # pass + # question_input.append(prompt_question) # input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) # preconfigure gen_kwargs with defaults - gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] + gen_kwargs["image_sizes"] = [flattened_visuals[idx].size for idx in range(len(flattened_visuals))] if "max_new_tokens" not in gen_kwargs: gen_kwargs["max_new_tokens"] = 1024 if "temperature" not in gen_kwargs: @@ -372,6 +411,7 @@ def _collate(x): text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True) except Exception as e: eval_logger.error(f"Error {e} in generating") + e.with_traceback() cont = "" text_outputs = [""] diff --git a/lmms_eval/tasks/docvqa/_default_template_docvqa_yaml b/lmms_eval/tasks/docvqa/_default_template_docvqa_yaml index 6e0cab68..920b89d1 100644 --- a/lmms_eval/tasks/docvqa/_default_template_docvqa_yaml +++ b/lmms_eval/tasks/docvqa/_default_template_docvqa_yaml @@ -17,3 +17,9 @@ model_specific_prompt_kwargs: qwen_vl: pre_prompt: "" post_prompt: " Answer:" +fewshot_config: + fewshot_sampler: default + fewshot_dataset: + dataset_path: lmms-lab/DocVQA + dataset_name: DocVQA + split: validation diff --git a/lmms_eval/tasks/flickr30k/flickr30k_test.yaml b/lmms_eval/tasks/flickr30k/flickr30k_test.yaml index 737d9ff4..776eae39 100644 --- a/lmms_eval/tasks/flickr30k/flickr30k_test.yaml +++ b/lmms_eval/tasks/flickr30k/flickr30k_test.yaml @@ -1,7 +1,7 @@ dataset_path: lmms-lab/flickr30k dataset_kwargs: token: True -task : "flickr30k_test" +task: "flickr30k_test" test_split: test output_type: generate_until doc_to_visual: !function utils.flickr_doc_to_visual @@ -40,5 +40,12 @@ metric_list: #- metric: flickr_SPICE # aggregation : !function utils.flickr_spice # higher_is_better : true +fewshot_config: + fewshot_sampler: default + fewshot_dataset: + dataset_path: lmms-lab/flickr30k + # dataset_name: + split: train + # random_seed: 1234 metadata: - version: 0.0 \ No newline at end of file diff --git a/lmms_eval/tasks/olympiadbench/cn_utils.py b/lmms_eval/tasks/olympiadbench/cn_utils.py index 34e5ce4d..628d51da 100644 --- a/lmms_eval/tasks/olympiadbench/cn_utils.py +++ b/lmms_eval/tasks/olympiadbench/cn_utils.py @@ -5,14 +5,17 @@ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file import logging + eval_logger = logging.getLogger("lmms-eval") dir_name = os.path.dirname(os.path.abspath(__file__)) olympiadbench_evaluator = OlympiadBenchEvaluator() + def olympiadbench_doc_to_visual(doc): return [image.convert("RGB") for image in doc["images"]] + def olympiadbench_doc_to_text(doc): question = doc["question"] subject = doc["subfield"] @@ -36,28 +39,26 @@ def olympiadbench_doc_to_text(doc): else: post_prompt += 
'"所以最终答案是\\boxed{用英⽂逗号连接的多个答案}。"\n' - final_question = pre_prompt + question + '\n' + post_prompt + final_question = pre_prompt + question + "\n" + post_prompt return final_question + def olympiadbench_process_results(doc, results): precision = doc["error"] - is_proving = "TP" in doc["source"] + is_proving = "TP" in doc["source"] if precision is None: precision = 0 prediction = results[0].strip() if is_proving: - return { - "submission": prediction - } + return {"submission": prediction} else: prediction = prediction.split("所以最终答案是")[-1] prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。") accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision) accuracy = int(accuracy) - return { - "exact_match": accuracy - } + return {"exact_match": accuracy} + def olympiadbench_aggregate_results(results, args): now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S") @@ -66,4 +67,3 @@ def olympiadbench_aggregate_results(results, args): with open(path, "w") as f: json.dump(results, f, ensure_ascii=False) print(f"Submission file saved to {path}") - \ No newline at end of file diff --git a/lmms_eval/tasks/olympiadbench/en_utils.py b/lmms_eval/tasks/olympiadbench/en_utils.py index a21ee159..4b165e38 100644 --- a/lmms_eval/tasks/olympiadbench/en_utils.py +++ b/lmms_eval/tasks/olympiadbench/en_utils.py @@ -5,14 +5,17 @@ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file import logging + eval_logger = logging.getLogger("lmms-eval") dir_name = os.path.dirname(os.path.abspath(__file__)) olympiadbench_evaluator = OlympiadBenchEvaluator() + def olympiadbench_doc_to_visual(doc): return [image.convert("RGB") for image in doc["images"]] + def olympiadbench_doc_to_text(doc): question = doc["question"] subject = doc["subfield"] @@ -30,34 +33,34 @@ def olympiadbench_doc_to_text(doc): post_prompt += f"The answer of the question should be {ans_type}.\n" else: post_prompt += f"The question has multiple answers, each of them should be {ans_type}.\n" - post_prompt += "Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with " + post_prompt += ( + "Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. 
Please end your solution with " + ) if not mul_ans: post_prompt += '"So the final answer is \\boxed{answer}."\n' else: - post_prompt += 'So the final answer is \\boxed{multiple answers connected with commas}.\n' + post_prompt += "So the final answer is \\boxed{multiple answers connected with commas}.\n" - final_question = pre_prompt + question + '\n' + post_prompt + final_question = pre_prompt + question + "\n" + post_prompt return final_question + def olympiadbench_process_results(doc, results): precision = doc["error"] - is_proving = "TP" in doc["source"] + is_proving = "TP" in doc["source"] if precision is None: precision = 0 prediction = results[0].strip() if is_proving: - return { - "submission": prediction - } + return {"submission": prediction} else: prediction = prediction.split("final answer is")[-1] prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。") accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision) accuracy = int(accuracy) - return { - "exact_match": accuracy - } + return {"exact_match": accuracy} + def olympiadbench_aggregate_results(results, args): now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S") @@ -66,4 +69,3 @@ def olympiadbench_aggregate_results(results, args): with open(path, "w") as f: json.dump(results, f, ensure_ascii=False) print(f"Submission file saved to {path}") - \ No newline at end of file diff --git a/lmms_eval/tasks/olympiadbench/olympiadbench_evals.py b/lmms_eval/tasks/olympiadbench/olympiadbench_evals.py index dd40f611..5ae36883 100644 --- a/lmms_eval/tasks/olympiadbench/olympiadbench_evals.py +++ b/lmms_eval/tasks/olympiadbench/olympiadbench_evals.py @@ -11,6 +11,7 @@ # precision = 1e-4 # res = scorer.judge(exp1, exp2, precision) + class OlympiadBenchEvaluator: def __init__(self): # Map of special symbols to their replacements @@ -46,8 +47,8 @@ def split_by_comma(self, expr: str): start_idx = i + 1 if start_idx < len(expr): - splitted_expr.append(expr[start_idx:].strip()) - + splitted_expr.append(expr[start_idx:].strip()) + return splitted_expr def trans_plus_minus_sign(self, expr_list: list): @@ -59,9 +60,9 @@ def trans_plus_minus_sign(self, expr_list: list): new_expr_list.append(expr.replace("\\pm", "-")) else: new_expr_list.append(expr) - + return new_expr_list - + def judge(self, expression1, expression2, precision=1e-8): # Judge if two expressions are equal (expression1 is considered as the Ground Truth) # Default precision is a list for supporting multiple expressions @@ -74,11 +75,11 @@ def judge(self, expression1, expression2, precision=1e-8): if expression1 == expression2: # print("Exactly equal") return True - + # Remove Chinese characters from the string, as answers like "yes" or "no" in Chinese have been considered - expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1) - expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2) - + expression1 = re.sub(r"[\u4e00-\u9fff]+", "", expression1) + expression2 = re.sub(r"[\u4e00-\u9fff]+", "", expression2) + expression1 = self.split_by_comma(expression1) expression2 = self.split_by_comma(expression2) @@ -88,7 +89,7 @@ def judge(self, expression1, expression2, precision=1e-8): # Set up a list for allowed errors if len(precision) <= 1: precision = precision * len(temp_list1) - + if len(temp_list1) != len(temp_list2): return False @@ -112,7 +113,7 @@ def judge(self, expression1, expression2, precision=1e-8): # If all elements are matched, return True return True - + def is_interval(self, expr): # 
Checks if an expression is an interval return expr.startswith(("(", "[")) and expr.endswith((")", "]")) @@ -120,7 +121,7 @@ def is_interval(self, expr): def sympy_sub_pi(self, expression_sympy): # Replaces the symbol for pi in sympy expressions with its numerical value return expression_sympy.subs(self.pi, math.pi) - + def is_equal(self, expression1, expression2): # Default first expression is ground truth. Check if expressions are equal in different aspects if expression1 == expression2 and expression1 != "" and expression2 != "": @@ -143,7 +144,7 @@ def is_equal(self, expression1, expression2): return True except: pass - + # Then check if expressions are mathematically equal try: if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2): @@ -151,7 +152,7 @@ def is_equal(self, expression1, expression2): return True except: pass - + # Lastly, check for equation equality try: if self.equation_equal(expression1, expression2): @@ -159,7 +160,7 @@ def is_equal(self, expression1, expression2): return True except: pass - + return False def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True): @@ -167,17 +168,16 @@ def numerical_equal(self, expression1: str, expression2: str, include_percentage # Includes possible percentage cases reference = float(expression1) prediction = float(expression2) - + if include_percentage: gt_result = [reference / 100, reference, reference * 100] else: gt_result = [reference] - + for item in gt_result: if abs(item - prediction) <= self.precision * 1.01: return True return False - def expression_equal(self, exp1, exp2): # Check if two expressions are mathematically equivalent @@ -186,7 +186,7 @@ def extract_expression(expression): if "=" in expression: expression = expression.split("=")[1] return expression.strip() - + exp1 = extract_expression(exp1) exp2 = extract_expression(exp2) @@ -204,7 +204,7 @@ def extract_expression(expression): elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol): try: if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)): - print(f"These two numbers cannot be calculated by the current computer for: \"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"") + print(f'These two numbers cannot be calculated by the current computer for: "{str(expr1_sym)}" and "{str(expr2_sym)}"') return False if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01: @@ -218,7 +218,7 @@ def extract_expression(expression): simplified_expr = simplify(expr1_sym - expr2_sym) num_value = simplified_expr.evalf() - + return abs(num_value) < 1e-3 except: return False @@ -227,7 +227,7 @@ def equation_equal(self, expression1, expression2): # Check if two equations are mathematically equivalent # Simplify equations and use sympy for equivalence checking def simplify_equation(latex_eq): - lhs, rhs = latex_eq.split('=') + lhs, rhs = latex_eq.split("=") lhs_expr = parse_latex(lhs) rhs_expr = parse_latex(rhs) @@ -254,18 +254,18 @@ def interval_equal(self, expression1, expression2): def compare_two_interval(inter1, inter2): if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]: return False - - inter1 = inter1.strip('[]()') - inter2 = inter2.strip('[]()') - items_1 = inter1.split(',') - items_2 = inter2.split(',') + inter1 = inter1.strip("[]()") + inter2 = inter2.strip("[]()") + + items_1 = inter1.split(",") + items_2 = inter2.split(",") for item_1, item_2 in zip(items_1, items_2): if not self.expression_equal(item_1, item_2): return False return True - 
+ interval1 = expression1 interval2 = expression2 @@ -274,7 +274,7 @@ def compare_two_interval(inter1, inter2): else: inter_list1 = interval1.split("\\cup") inter_list2 = interval2.split("\\cup") - + if len(inter_list1) != len(inter_list2): return False else: @@ -286,7 +286,7 @@ def compare_two_interval(inter1, inter2): def preprocess(self, expression1, expression2): # Preprocess expressions to extract and replace special symbols def extract_boxed_content(latex_str): - boxed_matches = re.finditer(r'\\boxed{', latex_str) + boxed_matches = re.finditer(r"\\boxed{", latex_str) results = "" for match in boxed_matches: @@ -295,14 +295,14 @@ def extract_boxed_content(latex_str): stack = 1 while stack > 0 and end_index < len(latex_str): - if latex_str[end_index] == '{': + if latex_str[end_index] == "{": stack += 1 - elif latex_str[end_index] == '}': + elif latex_str[end_index] == "}": stack -= 1 end_index += 1 if stack == 0: - content = latex_str[start_index:end_index - 1] + content = latex_str[start_index : end_index - 1] results += content + "," else: raise ValueError("Mismatched braces in LaTeX string.") @@ -317,28 +317,28 @@ def extract_boxed_content(latex_str): results += ans + "," else: results = latex_str - + return results - + def sepcial_symbol_replace(expression): if "\\in " in expression: expression = expression.split("\\in ")[1] - + for signal in self.special_signal_map: expression = expression.replace(signal, self.special_signal_map[signal]) expression = expression.strip("\n$,.:;^_=+`!@#$%^&*~,。") - pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}' - expression = re.sub(pattern, r'\1', expression) + pattern = r"\\(?:mathrm|mathbf)\{~?([^}]*)\}" + expression = re.sub(pattern, r"\1", expression) return expression - + exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2) exp1, exp2 = sepcial_symbol_replace(exp1), sepcial_symbol_replace(exp2) return exp1, exp2 - + def can_compute_power(self, expr): # Checks if a power expression can be computed if isinstance(expr, Pow): @@ -352,4 +352,4 @@ def can_compute_power(self, expr): else: return False else: - return True # Not a power expression, can compute \ No newline at end of file + return True # Not a power expression, can compute diff --git a/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml b/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml index 282b401b..5c6205e9 100644 --- a/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml +++ b/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml @@ -2,7 +2,7 @@ dataset_path: lmms-lab/textvqa output_type: generate_until doc_to_visual: !function utils.textvqa_doc_to_visual doc_to_text: !function utils.textvqa_doc_to_text -doc_to_target: "answer" +doc_to_target: !function utils.textvqa_doc_to_target generation_kwargs: until: - "ASSISTANT:" @@ -15,3 +15,10 @@ model_specific_prompt_kwargs: qwen_vl: pre_prompt: "" post_prompt: " Answer:" +fewshot_config: + fewshot_sampler: default + fewshot_dataset: + # dataset_path: lmms-lab/flickr30k + # dataset_name: + split: train + # random_seed: 1234 diff --git a/lmms_eval/tasks/textvqa/utils.py b/lmms_eval/tasks/textvqa/utils.py index ea3b503b..7a1ef194 100644 --- a/lmms_eval/tasks/textvqa/utils.py +++ b/lmms_eval/tasks/textvqa/utils.py @@ -1,5 +1,6 @@ import re import os +import random import json import yaml import pathlib @@ -66,3 +67,8 @@ def textvqa_aggreate_submissions(results, args): json.dump(results, f) # print(f"Submission file saved to {path}") eval_logger.info(f"Submission file saved to {path}") + + +def 
textvqa_doc_to_target(doc): + answers = doc["answers"] + return random.choice(answers) if isinstance(answers, list) else answers diff --git a/lmms_eval/utils.py b/lmms_eval/utils.py index d6c08553..93ecf0db 100644 --- a/lmms_eval/utils.py +++ b/lmms_eval/utils.py @@ -822,6 +822,8 @@ def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List: Yields: List: Yields reordered elements one by one. """ + from lmms_eval.api.samplers import Context + arr = sorted(arr, key=lambda x: self.fn(x[1])) self.reorder_indices.extend([x[0] for x in arr]) yield from [x[1] for x in arr] diff --git a/pyproject.toml b/pyproject.toml index c50c4e76..761575da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "tiktoken", "pre-commit", "pydantic", + "antlr4-python3-runtime==4.11", ] [tool.setuptools.packages.find]
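
Not part of the patch: a minimal usage sketch of the few-shot plumbing this diff introduces (FewShotDataset, ContextSampler, Context). The `task` and `doc` names are hypothetical stand-ins for a ConfigurableTask-style object (exposing doc_to_text / doc_to_target / doc_to_choice / doc_to_visual and `_config`) and an evaluation document; the dataset values are copied from the DocVQA YAML above.

import random

from lmms_eval.api.samplers import ContextSampler, FewShotDataset

# Few-shot examples can now come from a different dataset/split than the one
# being evaluated; same_as_eval=True makes the sampler draw one extra example
# so the evaluated doc can be dropped if it happens to be sampled.
fewshot_docs = FewShotDataset(
    dataset_path="lmms-lab/DocVQA",  # values taken from the DocVQA YAML in this diff
    dataset_name="DocVQA",
    split="validation",
    same_as_eval=False,
)

sampler = ContextSampler(fewshot_docs, task, rnd=random.Random(1234))  # `task` is an assumed task object

# get_context now returns a Context object rather than a flat prompt string.
ctx = sampler.get_context(doc, num_fewshot=2)  # `doc` is an assumed evaluation document
prompt_text = ctx.get_text()   # few-shot Q/A pairs rendered as "USER: ... ASSISTANT: ..." text
images = ctx.get_visions()     # images loaded lazily through each example's doc_to_visual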