[WIP] feat: add multimodal fewshot evaluation #365

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions lmms_eval/api/samplers.py
@@ -12,11 +12,58 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.doc_to_text = self.task.doc_to_text
self.doc_to_target = self.task.doc_to_target
self.doc_to_choice = self.task.doc_to_choice
self.doc_to_visual = self.task.doc_to_visual

self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # select a subset of the few-shot docs by index
self.docs = self.docs.select(fewshot_indices)

def get_multimodal_context(self, doc, num_fewshot):
n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

labeled_examples = self.fewshot_delimiter.join([self._format_single_example(d) for d in selected_docs]) + self.fewshot_delimiter

# str, list of PIL Images
return labeled_examples, self._collect_images(selected_docs)

def _format_single_example(self, doc):
# Replace actual image content with placeholder
image_placeholder = "<image>"

# Get text content (question/prompt)
text = self.doc_to_text(doc)
text_content = text if (self.config.doc_to_choice is None or isinstance(text, str)) else self.doc_to_choice(doc)[text]

# Get target/label
target = self._format_target(doc)

# Combine with image placeholder
return f"{image_placeholder}\n{text_content}{self.target_delimiter}{target}"

def _collect_images(self, docs):
# Collect the PIL Images for each doc into one flat list
image_list = []
for doc in docs:
image = self.doc_to_visual(doc) # doc_to_visual returns a list of PIL Images for this doc
image_list.append(image)
# flatten list of lists
image_list = [item for sublist in image_list for item in sublist]
return image_list

def _format_target(self, doc):
target = self.doc_to_target(doc)

if isinstance(target, list):
return str(target[0])
elif self.config.doc_to_choice is None or isinstance(target, str):
return target
else:
return str(self.doc_to_choice(doc)[target])
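# Worked example of the three target shapes handled above (hypothetical docs):
#   doc_to_target(doc) == ["Paris", "paris"]                  -> "Paris" (list: first element)
#   doc_to_target(doc) == "Paris"                             -> "Paris" (string passthrough)
#   doc_to_target(doc) == 2, doc_to_choice == ["A", "B", "C"] -> "C"     (index into choices)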

def get_context(self, doc, num_fewshot):
# draw an extra fewshot sample if using same split as evaluating on
n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
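A minimal usage sketch for the new sampler method (not part of this diff; `sampler`, `task`, and `doc` are assumed to be a ContextSampler, its task, and the document under evaluation):

labeled_text, fewshot_images = sampler.get_multimodal_context(doc, num_fewshot=2)

# labeled_text holds one "<image>\n{question}{target_delimiter}{answer}" block per
# few-shot example; fewshot_images is the matching flat list of PIL Images.
prompt = labeled_text + task.doc_to_text(doc)
visuals = fewshot_images + task.doc_to_visual(doc)  # few-shot images come first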
121 changes: 71 additions & 50 deletions lmms_eval/api/task.py
@@ -1034,19 +1034,14 @@ def concat_tar_parts(tar_parts, output_tar):
if "create_link" in dataset_kwargs:
dataset_kwargs.pop("create_link")

if "load_from_disk" in dataset_kwargs and dataset_kwargs["load_from_disk"]:
dataset_kwargs.pop("load_from_disk")
# using a local task in an offline environment; first materialize the online dataset locally via
# `ds = datasets.load_dataset("lmms-lab/MMMU")`
self.dataset = datasets.load_from_disk(path=self.DATASET_PATH, name=self.DATASET_NAME)
else:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)
# Always load via `datasets.load_dataset`; extra options flow in through `dataset_kwargs`
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)

if self.config.process_docs is not None:
for split in self.dataset:
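For reference, the offline workflow the removed `load_from_disk` branch was aiming at can be done with the public datasets API alone; a hedged sketch with a placeholder local path (`datasets.load_from_disk` takes a bare path and no `name` argument):

import datasets

# Materialize the dataset once while online, then reload it offline.
ds = datasets.load_dataset("lmms-lab/MMMU")   # one-time online download
ds.save_to_disk("/data/mmmu_local")           # write Arrow files locally

# In the offline environment:
ds = datasets.load_from_disk("/data/mmmu_local")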
@@ -1114,6 +1109,7 @@ def fewshot_context(
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
is_multimodal: bool = False,
Collaborator:

Based on my understanding of the changes to this fewshot_context call, is_multimodal is currently always False?

# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
    doc,
    0 if self.config.num_fewshot is None else self.config.num_fewshot,
    system_instruction,
    apply_chat_template,
    fewshot_as_multiturn,
    chat_template,
)

Contributor Author:

Yes, it's not called in the current lmms-eval standard pipeline, but another project calls it by importing from lmms_eval.

The actual code is as follows:

visual = self.task_obj.doc_to_visual(self.hf_dataset[index])
if self.shot_number > 0 and (self.task_group_name in ["mmlu_flan_n_shot_generative"] or self.task_name in ["mmlu_flan_n_shot_local_all", "textvqa_val_fewshot", "textcaps_val_fewshot"]):
    if self.task_name in ["mmlu_flan_n_shot_local_all"]:
        is_multimodal = False
    else:
        is_multimodal = True
    context, multimodal_ctx = self.task_obj.fewshot_context(self.hf_dataset[index], num_fewshot=self.shot_number, system_instruction="", is_multimodal=is_multimodal)
    context = context.strip("\n\n")
    if isinstance(multimodal_ctx, list) and isinstance(visual, list):
        visual = multimodal_ctx + visual
    else:
        raise ValueError(f"multimodal_ctx: {multimodal_ctx} is not a list, please check the task_obj.fewshot_context")
else:
    context = self.task_obj.doc_to_text(self.hf_dataset[index])

Contributor Author:

So I set it to WIP, hoping we can refine this PR together lol.

I think Fanyi @pufanyi did something similar many months ago.

) -> str:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1162,48 +1158,73 @@ def fewshot_context(

# if few-shot - append examples after the system prompt
if num_fewshot > 0:
if apply_chat_template:
labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn))
if is_multimodal is False:
if apply_chat_template:
labeled_examples.extend(self.sampler.get_chat_context(doc, num_fewshot, fewshot_as_multiturn))
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
if apply_chat_template:
labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_chat_context(doc, num_fewshot, fewshot_as_multiturn)
labeled_examples.extend(labeled_examples_text)
else:
labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_context(doc, num_fewshot)
labeled_examples += labeled_examples_text
Comment on lines +1167 to +1172
Collaborator:

Since is_multimodal is always False (it is not passed in when calling fewshot_context), will this get_multimodal_context ever get called? It seems we can ignore the apply_chat_template branch here, since it is never used.


example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
if is_multimodal is False:
if apply_chat_template:
if self.multiple_input:
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(labeled_examples, example, fewshot_as_multiturn)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
labeled_examples_list = []
# copy chat history for each example and append the answer
for ex in example:
chat = copy.deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn)
else:
self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(labeled_examples, example, fewshot_as_multiturn)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
labeled_examples_list = []
# copy chat history for each example and append the answer
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(labeled_examples, choices[example], fewshot_as_multiturn)
else:
self.append_target_question(labeled_examples, str(example), fewshot_as_multiturn)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
else:
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
else:
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
if apply_chat_template:
raise NotImplementedError("Multimodal chat template not implemented yet")
else:
if self.multiple_input:
return labeled_examples, labeled_examples_multimodal
if isinstance(example, str):
return labeled_examples + example, labeled_examples_multimodal
elif isinstance(example, list):
return [labeled_examples + ex for ex in example], labeled_examples_multimodal
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example], labeled_examples_multimodal
else:
return labeled_examples + str(example), labeled_examples_multimodal
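To make the new return shape concrete, a hedged sketch of the non-chat multimodal path, following the author's snippet above (`task_obj` and `doc` are assumed to be an lmms-eval task instance and one of its documents):

# With is_multimodal=True and no chat template, fewshot_context returns a
# (text, images) tuple instead of a plain string.
context, fewshot_images = task_obj.fewshot_context(
    doc,
    num_fewshot=2,
    system_instruction="",
    is_multimodal=True,
)

# Few-shot images are prepended so they line up with the "<image>" placeholders
# that come before the final question's own visuals.
visuals = fewshot_images + task_obj.doc_to_visual(doc)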

def apply_filters(self):
if hasattr(self, "_filters"):