From c5abe5723c06499696e1d7cedfba78da7a852625 Mon Sep 17 00:00:00 2001 From: Bo Li Date: Mon, 28 Oct 2024 02:32:25 +0000 Subject: [PATCH] feat: add multimdoal fewshot evaluation --- lmms_eval/api/samplers.py | 47 +++++++++++++++ lmms_eval/api/task.py | 121 ++++++++++++++++++++++---------------- 2 files changed, 118 insertions(+), 50 deletions(-) diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py index 2cecfe22..56b0de4a 100755 --- a/lmms_eval/api/samplers.py +++ b/lmms_eval/api/samplers.py @@ -12,11 +12,58 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None: self.doc_to_text = self.task.doc_to_text self.doc_to_target = self.task.doc_to_target self.doc_to_choice = self.task.doc_to_choice + self.doc_to_visual = self.task.doc_to_visual self.docs = docs # HF dataset split, provided by task._fewshot_docs() if fewshot_indices: # subset few-shot docs from self.docs = self.docs.select(fewshot_indices) + def get_multimodal_context(self, doc, num_fewshot): + n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = self.fewshot_delimiter.join([self._format_single_example(doc) for doc in selected_docs]) + self.fewshot_delimiter + + # str, list of PIL Images + return labeled_examples, self._collect_images(selected_docs) + + def _format_single_example(self, doc): + # Replace actual image content with placeholder + image_placeholder = f"" + + # Get text content (question/prompt) + text_content = self.doc_to_text(doc) if (self.config.doc_to_choice is None or isinstance(self.doc_to_text(doc), str)) else self.doc_to_choice(doc)[self.doc_to_text(doc)] + + # Get target/label + target = self._format_target(doc) + + # Combine with image placeholder + return f"{image_placeholder}\n{text_content}{self.target_delimiter}{target}" + + def _collect_images(self, docs): + # Create a dictionary mapping image placeholders to actual PIL Images + image_list = [] + for doc in docs: + image = self.doc_to_visual(doc) # Assuming this is the PIL Image + image_list.append(image) + # flatten list of lists + image_list = [item for sublist in image_list for item in sublist] + return image_list + + def _format_target(self, doc): + target = self.doc_to_target(doc) + + if isinstance(target, list): + return str(target[0]) + elif self.config.doc_to_choice is None or isinstance(target, str): + return target + else: + return str(self.doc_to_choice(doc)[target]) + def get_context(self, doc, num_fewshot): # draw an extra fewshot sample if using same split as evaluating on n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index bad41dca..aa5e12d0 100755 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -1034,19 +1034,14 @@ def concat_tar_parts(tar_parts, output_tar): if "create_link" in dataset_kwargs: dataset_kwargs.pop("create_link") - if "load_from_disk" in dataset_kwargs and dataset_kwargs["load_from_disk"]: - dataset_kwargs.pop("load_from_disk") - # using local task in offline environment, need to process the online dataset into local format via - # `ds = load_datasets("lmms-lab/MMMU")` - self.dataset = datasets.load_from_disk(path=self.DATASET_PATH, name=self.DATASET_NAME) - else: - self.dataset = datasets.load_dataset( - path=self.DATASET_PATH, - name=self.DATASET_NAME, - download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, - download_config=download_config, - **dataset_kwargs if dataset_kwargs is not None else {}, - ) + # Check if the key exists first + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, + download_config=download_config, + **dataset_kwargs if dataset_kwargs is not None else {}, + ) if self.config.process_docs is not None: for split in self.dataset: @@ -1114,6 +1109,7 @@ def fewshot_context( apply_chat_template: bool = False, fewshot_as_multiturn: bool = False, chat_template: Optional[Callable] = None, + is_multimodal: bool = False, ) -> str: """Returns a fewshot context string that is made up of a prepended description (if provided), the `num_fewshot` number of examples, and an appended prompt example. @@ -1162,48 +1158,73 @@ def fewshot_context( # if few-shot - append examples after the system prompt if num_fewshot > 0: - if apply_chat_template: - labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn)) + if is_multimodal is False: + if apply_chat_template: + labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn)) + else: + labeled_examples += self.sampler.get_context(doc, num_fewshot) else: - labeled_examples += self.sampler.get_context(doc, num_fewshot) + if apply_chat_template: + labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_chat_context(doc, num_fewshot, fewshot_as_multiturn) + labeled_examples.extend(labeled_examples_text) + else: + labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_context(doc, num_fewshot) + labeled_examples += labeled_examples_text example = self.doc_to_text(doc) - if apply_chat_template: - if self.multiple_input: + if is_multimodal is False: + if apply_chat_template: + if self.multiple_input: + return chat_template(labeled_examples) + if isinstance(example, str): + self.append_target_question(labeled_examples, example, fewshot_as_multiturn) + # for loglikelihood create a list of questions with appended choices + elif isinstance(example, list): + labeled_examples_list = [] + # copy chat history for each example and append the answer + for ex in example: + chat = copy.deepcopy(labeled_examples) + self.append_target_question(chat, ex, fewshot_as_multiturn) + labeled_examples_list.append(chat_template(chat)) + return labeled_examples_list + # if example is an integer, append the choice or convert to string + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn) + else: + self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn) + # return lm.apply_chat_template(labeled_examples) return chat_template(labeled_examples) - if isinstance(example, str): - self.append_target_question(labeled_examples, example, fewshot_as_multiturn) - # for loglikelihood create a list of questions with appended choices - elif isinstance(example, list): - labeled_examples_list = [] - # copy chat history for each example and append the answer - for ex in example: - chat = deepcopy(labeled_examples) - self.append_target_question(chat, ex, fewshot_as_multiturn) - labeled_examples_list.append(chat_template(chat)) - return labeled_examples_list - # if example is an integer, append the choice or convert to string - elif isinstance(example, int): - if self.config.doc_to_choice is not None: - choices = self.doc_to_choice(doc) - self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn) - else: - self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn) - # return lm.apply_chat_template(labeled_examples) - return chat_template(labeled_examples) + else: + if self.multiple_input: + return labeled_examples + if isinstance(example, str): + return labeled_examples + example + elif isinstance(example, list): + return [labeled_examples + ex for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example] + else: + return labeled_examples + str(example) else: - if self.multiple_input: - return labeled_examples - if isinstance(example, str): - return labeled_examples + example - elif isinstance(example, list): - return [labeled_examples + ex for ex in example] - elif isinstance(example, int): - if self.config.doc_to_choice is not None: - choices = self.doc_to_choice(doc) - return labeled_examples + choices[example] - else: - return labeled_examples + str(example) + if apply_chat_template: + raise NotImplementedError("Multimodal chat template not implemented yet") + else: + if self.multiple_input: + return labeled_examples, labeled_examples_multimodal + if isinstance(example, str): + return labeled_examples + example, labeled_examples_multimodal + elif isinstance(example, list): + return [labeled_examples + ex for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example], labeled_examples_multimodal + else: + return labeled_examples + str(example), labeled_examples_multimodal def apply_filters(self): if hasattr(self, "_filters"):