
[WIP] feat: add multimodal fewshot evaluation #365

Open

Luodian wants to merge 1 commit into main

Conversation

@Luodian (Contributor) commented on Oct 28, 2024

This may still need improvement to cover all situations and to ensure backward compatibility.

This pull request introduces support for multimodal contexts in the lmms_eval module, enhancing its ability to handle both text and visual data. The most important changes add methods for managing multimodal data in samplers.py and update the fewshot_context function in task.py to accommodate them.

Enhancements for multimodal support:

  • [lmms_eval/api/samplers.py]: Added the methods get_multimodal_context, _format_single_example, _collect_images, and _format_target to handle multimodal data, including text and visual content (see the sketch after this list).
  • [lmms_eval/api/task.py]: Updated the fewshot_context function with an is_multimodal parameter and logic for multimodal examples, raising a NotImplementedError for multimodal chat templates.
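
For orientation, here is a minimal sketch of what get_multimodal_context could look like. The helper names come from the list above, but the bodies, and the assumption that the sampler exposes the harness's usual sample(n), are illustrative rather than the actual diff:

def get_multimodal_context(self, doc, num_fewshot):
    """Return (fewshot_text, fewshot_images) for num_fewshot examples."""
    # Draw one extra candidate so the evaluated doc can be dropped if sampled.
    candidates = [x for x in self.sample(num_fewshot + 1) if x != doc][:num_fewshot]
    text_parts, images = [], []
    for example in candidates:
        text_parts.append(self._format_single_example(example))  # question + target text
        images.extend(self._collect_images(example))  # e.g. via the task's doc_to_visual
    return "\n\n".join(text_parts) + "\n\n", images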

Code improvements:

  • [lmms_eval/api/task.py]: Replaced deepcopy with copy.deepcopy for clarity and consistency (L1181-R1186).
  • [lmms_eval/api/task.py]: Simplified dataset loading by removing redundant checks for the load_from_disk key (L1037-R1037).

Fewshot demo YAML (textvqa)

task: textvqa_val_fewshot
test_split: validation
training_split: train
validation_split: validation
fewshot_split: train
fewshot_config:
  sampler: first_n
  
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true

lmms_eval_specific_kwargs:
  default:
    pre_prompt: "Question: "
    post_prompt: " Short answer in one phrase or single word:"
    ocr: false
  qwen_vl:
    pre_prompt: ""
    post_prompt: " Answer:"
    
include: _default_template_textvqa_yaml
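
The include: line pulls in the shared textvqa template, with the keys above taking precedence. Conceptually the merge works like this (a simplified sketch; lmms-eval's actual config loader may differ):

import os
import yaml

def load_task_config(path):
    # Simplified sketch of include: merging; not lmms-eval's real loader.
    with open(path) as f:
        config = yaml.safe_load(f) or {}
    include = config.pop("include", None)
    if include:
        base = load_task_config(os.path.join(os.path.dirname(path), include))
        base.update(config)  # keys in the including file override the template
        config = base
    return config

Under this scheme, textvqa_val_fewshot would inherit doc_to_text, doc_to_visual, and the rest from _default_template_textvqa_yaml while overriding the fewshot splits defined above.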

If you run into lint warnings, you can use the following commands to reformat the code.

pip install pre-commit
pre-commit install
pre-commit run --all-files

Thank you for your contributions!

@Luodian requested review from @pufanyi and @kcz358 on Oct 28, 2024, 02:35
@@ -1114,6 +1109,7 @@ def fewshot_context(
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
+        is_multimodal: bool = False,
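
For context, the full signature after this hunk would read roughly as follows; this is reconstructed from the hunk and the call sites quoted in this thread, not copied from the file:

from typing import Callable, Optional

def fewshot_context(
    self,
    doc: dict,
    num_fewshot: int,
    system_instruction: Optional[str] = None,
    apply_chat_template: bool = False,
    fewshot_as_multiturn: bool = False,
    chat_template: Optional[Callable] = None,
    is_multimodal: bool = False,  # new in this PR
):
    ...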
Collaborator

Based on my understanding of the changes to this fewshot_context call, is_multimodal is currently always False?

# sample fewshot context  # TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
    doc,
    0 if self.config.num_fewshot is None else self.config.num_fewshot,
    system_instruction,
    apply_chat_template,
    fewshot_as_multiturn,
    chat_template,
)

Contributor (author)

Yes, it's not called in the current lmms-eval standard pipeline, but another project calls it by importing it from lmms_eval.

The actual calling code is as follows.

visual = self.task_obj.doc_to_visual(self.hf_dataset[index])
if self.shot_number > 0 and (self.task_group_name in ["mmlu_flan_n_shot_generative"] or self.task_name in ["mmlu_flan_n_shot_local_all", "textvqa_val_fewshot", "textcaps_val_fewshot"]):
    # mmlu_flan_n_shot_local_all is text-only; the other fewshot tasks carry images.
    if self.task_name in ["mmlu_flan_n_shot_local_all"]:
        is_multimodal = False
    else:
        is_multimodal = True
    context, multimodal_ctx = self.task_obj.fewshot_context(self.hf_dataset[index], num_fewshot=self.shot_number, system_instruction="", is_multimodal=is_multimodal)
    context = context.strip("\n\n")  # str.strip treats the argument as a character set, so this strips newlines
    # Prepend the fewshot images to the current doc's own visuals.
    if isinstance(multimodal_ctx, list) and isinstance(visual, list):
        visual = multimodal_ctx + visual
    else:
        raise ValueError(f"multimodal_ctx: {multimodal_ctx} is not a list, please check the task_obj.fewshot_context")
else:
    context = self.task_obj.doc_to_text(self.hf_dataset[index])

Contributor (author)

So I set it to WIP, hoping we can refine this PR together lol.

I think Fanyi @pufanyi did something similar many months ago.

Comment on lines +1167 to +1172
if apply_chat_template:
    labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_chat_context(doc, num_fewshot, fewshot_as_multiturn)
    labeled_examples.extend(labeled_examples_text)
else:
    labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_context(doc, num_fewshot)
    labeled_examples += labeled_examples_text
Collaborator

Since is_multimodal is always False (it is not passed in when fewshot_context is called), will this get_multimodal_context ever get called? It seems we can ignore the apply_chat_template branch here, since it is never used.
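
If the chat-template path really is dead code for now, the hunk above could presumably shrink to the plain branch (a sketch of the reviewer's suggestion, not a committed change):

# Drop the never-taken chat-template branch until multimodal chat templates land.
labeled_examples_text, labeled_examples_multimodal = self.sampler.get_multimodal_context(doc, num_fewshot)
labeled_examples += labeled_examples_text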
