Skip to content

Commit

Permalink
feat: add multimodal fewshot evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
Luodian committed Oct 28, 2024
1 parent bc2899c commit c5abe57
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 50 deletions.
47 changes: 47 additions & 0 deletions lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,58 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.doc_to_text = self.task.doc_to_text
self.doc_to_target = self.task.doc_to_target
self.doc_to_choice = self.task.doc_to_choice
self.doc_to_visual = self.task.doc_to_visual

self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
self.docs = self.docs.select(fewshot_indices)

def get_multimodal_context(self, doc, num_fewshot):
    """Build the fewshot prefix for a multimodal task.

    Args:
        doc: the document currently being evaluated (excluded from the fewshot
            examples if it happens to be drawn).
        num_fewshot: number of fewshot examples to include.

    Returns:
        A ``(labeled_examples, images)`` tuple: the joined fewshot text (with
        ``<image>`` placeholders) followed by a trailing delimiter, and the
        flattened list of images (presumably PIL Images — see _collect_images)
        for the selected docs, in order.
    """
    # Draw one extra sample when the fewshot split is the test split, so we can
    # drop `doc` itself if it gets drawn.
    n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
    # draw `n_samples` docs from fewshot_docs
    fewshotex = self.sample(n_samples)
    # get rid of the doc that's the one we're evaluating, if it's in the fewshot
    # TODO: should we just stop people from using fewshot from same split as evaluating?
    selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

    # NOTE: the loop variable is named `ex`, not `doc`, to avoid shadowing the
    # parameter we just filtered against.
    labeled_examples = self.fewshot_delimiter.join(self._format_single_example(ex) for ex in selected_docs) + self.fewshot_delimiter

    # str, list of PIL Images
    return labeled_examples, self._collect_images(selected_docs)

def _format_single_example(self, doc):
# Replace actual image content with placeholder
image_placeholder = f"<image>"

# Get text content (question/prompt)
text_content = self.doc_to_text(doc) if (self.config.doc_to_choice is None or isinstance(self.doc_to_text(doc), str)) else self.doc_to_choice(doc)[self.doc_to_text(doc)]

# Get target/label
target = self._format_target(doc)

# Combine with image placeholder
return f"{image_placeholder}\n{text_content}{self.target_delimiter}{target}"

def _collect_images(self, docs):
# Create a dictionary mapping image placeholders to actual PIL Images
image_list = []
for doc in docs:
image = self.doc_to_visual(doc) # Assuming this is the PIL Image
image_list.append(image)
# flatten list of lists
image_list = [item for sublist in image_list for item in sublist]
return image_list

def _format_target(self, doc):
target = self.doc_to_target(doc)

if isinstance(target, list):
return str(target[0])
elif self.config.doc_to_choice is None or isinstance(target, str):
return target
else:
return str(self.doc_to_choice(doc)[target])

def get_context(self, doc, num_fewshot):
# draw an extra fewshot sample if using same split as evaluating on
n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
Expand Down
121 changes: 71 additions & 50 deletions lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,19 +1034,14 @@ def concat_tar_parts(tar_parts, output_tar):
if "create_link" in dataset_kwargs:
dataset_kwargs.pop("create_link")

if "load_from_disk" in dataset_kwargs and dataset_kwargs["load_from_disk"]:
dataset_kwargs.pop("load_from_disk")
# using local task in offline environment, need to process the online dataset into local format via
# `ds = load_datasets("lmms-lab/MMMU")`
self.dataset = datasets.load_from_disk(path=self.DATASET_PATH, name=self.DATASET_NAME)
else:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)
# Check if the key exists first
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)

if self.config.process_docs is not None:
for split in self.dataset:
Expand Down Expand Up @@ -1114,6 +1109,7 @@ def fewshot_context(
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
is_multimodal: bool = False,
) -> str:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
Expand Down Expand Up @@ -1162,48 +1158,73 @@ def fewshot_context(

# if few-shot - append examples after the system prompt
if num_fewshot > 0:
if apply_chat_template:
labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn))
if is_multimodal is False:
if apply_chat_template:
labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn))
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
if apply_chat_template:
labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_chat_context(doc, num_fewshot, fewshot_as_multiturn)
labeled_examples.extend(labeled_examples_text)
else:
labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_context(doc, num_fewshot)
labeled_examples += labeled_examples_text

example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
if is_multimodal is False:
if apply_chat_template:
if self.multiple_input:
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(labeled_examples, example, fewshot_as_multiturn)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
labeled_examples_list = []
# copy chat history for each example and append the answer
for ex in example:
chat = copy.deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn)
else:
self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(labeled_examples, example, fewshot_as_multiturn)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
labeled_examples_list = []
# copy chat history for each example and append the answer
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn)
else:
self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
else:
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
else:
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
if apply_chat_template:
raise NotImplementedError("Multimodal chat template not implemented yet")
else:
if self.multiple_input:
return labeled_examples, labeled_examples_multimodal
if isinstance(example, str):
return labeled_examples + example, labeled_examples_multimodal
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example], labeled_examples_multimodal
else:
return labeled_examples + str(example), labeled_examples_multimodal

def apply_filters(self):
if hasattr(self, "_filters"):
Expand Down

0 comments on commit c5abe57

Please sign in to comment.