Commit 8928a5a (1 parent: 10be8c3)
Showing 9 changed files with 494 additions and 0 deletions.
`_default_template_yaml` (new file, 9 additions):

```yaml
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single character from the given options."
generation_kwargs:
  max_new_tokens: 8
metadata:
  version: 0.0
  task_type: image
```
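For reference, a sketch of the prompt this template yields once `mmt_doc_to_text` (defined in `utils.py` below) applies it; the question and options here are hypothetical:

```python
# Illustrative only: the prompt string mmt_doc_to_text builds with the
# default template above (hypothetical question and options).
prompt = (
    "Question: <image>\n"            # pre_prompt is empty
    "What is shown in the image?\n"  # doc["question"] (hypothetical)
    "A: a cat\n"
    "B: a dog"
    "\nAnswer the question using a single character from the given options."  # post_prompt
)
print(prompt)
```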
Group definition for `mmt` (new file, 4 additions):

```yaml
group: mmt
task:
  - mmt_val
  - mmt_test
```
Group definition for `mmt_mi` (new file, 4 additions):

```yaml
group: mmt_mi
task:
  - mmt_mi_val
  - mmt_mi_test
```
Task config for `mmt_mi_test` (new file, 15 additions):

```yaml
dataset_path: lmms-lab/MMT_MI-Benchmark
dataset_kwargs:
  token: True
task: "mmt_mi_test"
test_split: test
doc_to_visual: !function utils_mi.mmt_doc_to_visual
doc_to_text: !function utils_mi.mmt_doc_to_text
doc_to_choice: !function utils_mi.mmt_doc_to_choice
doc_to_target: !function utils_mi.mmt_doc_to_target
process_results: !function utils_mi.mmt_process_results
metric_list:
  - metric: submission
    aggregation: !function utils_mi.mmt_test_aggregate_results_for_submission
    higher_is_better: true
include: _default_template_yaml
```
Task config for `mmt_mi_val` (new file, 15 additions):

```yaml
dataset_path: lmms-lab/MMT_MI-Benchmark
dataset_kwargs:
  token: True
task: "mmt_mi_val"
test_split: val
doc_to_visual: !function utils_mi.mmt_doc_to_visual
doc_to_text: !function utils_mi.mmt_doc_to_text
doc_to_choice: !function utils_mi.mmt_doc_to_choice
doc_to_target: !function utils_mi.mmt_doc_to_target
process_results: !function utils_mi.mmt_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils_mi.mmt_aggregate_results
    higher_is_better: true
include: _default_template_yaml
```
Task config for `mmt_test` (new file, 15 additions):

```yaml
dataset_path: lmms-lab/MMT-Benchmark
dataset_kwargs:
  token: True
task: "mmt_test"
test_split: test
doc_to_visual: !function utils.mmt_doc_to_visual
doc_to_text: !function utils.mmt_doc_to_text
doc_to_choice: !function utils.mmt_doc_to_choice
doc_to_target: !function utils.mmt_doc_to_target
process_results: !function utils.mmt_process_results
metric_list:
  - metric: submission
    aggregation: !function utils.mmt_test_aggregate_results_for_submission
    higher_is_better: true
include: _default_template_yaml
```
Task config for `mmt_val` (new file, 15 additions):

```yaml
dataset_path: lmms-lab/MMT-Benchmark
dataset_kwargs:
  token: True
task: "mmt_val"
test_split: val
doc_to_visual: !function utils.mmt_doc_to_visual
doc_to_text: !function utils.mmt_doc_to_text
doc_to_choice: !function utils.mmt_doc_to_choice
doc_to_target: !function utils.mmt_doc_to_target
process_results: !function utils.mmt_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.mmt_aggregate_results
    higher_is_better: true
include: _default_template_yaml
```
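Each per-document result feeding these metrics comes from `mmt_process_results` in `utils.py` below: the val configs aggregate the `accuracy` entry, while the test configs collect the `submission` entry into a submission file. A rough sketch of the shape, with hypothetical values:

```python
# Hypothetical per-document result as returned by mmt_process_results:
example_result = {
    "accuracy": {
        "overall": 1.0,  # 1.0 if the parsed prediction matched the answer, else 0.0
        "ocr": 1.0,      # the same score keyed by the doc's l2-category
    },
    "submission": {"1234": "B"},  # doc index -> predicted letter (TEST split only)
}
```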
`utils.py` (new file, 209 additions):

```python
from PIL import Image
from collections import defaultdict
from loguru import logger as eval_logger
import random
import json
import re
import pandas as pd
import numpy as np

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file


# All abbreviations for the categories
MMT_abbrs = {
    "visual_recognition": "VR",
    "localization": "Loc",
    "ocr": "OCR",
    "counting": "Count",
    "hallucination": "HLN",
    "image_retrieval": "IR",
    "threed": "3D",
    "visual_captioning": "VC",
    "visual_grounding": "VG",
    "doc_understanding": "DU",
    "action_recognition": "AR",
    "pixel_level_perception": "PLP",
    "image-to-image_translation": "I2IT",
    "relation_reasoning": "RR",
    "intelligence_quotient_test": "IQT",
    "emotion": "Emo",
    "visual_illusion": "VI",
    "meme_understanding": "MemU",
    "visual_prompt_understanding": "VPU",
    "anomaly_detection": "AND",
    "keypoint_detection": "KD",
    "visual_commonsense_reasoning": "VCR",
    "image_evaluation_judgement": "IEJ",
    "multiple_image_analysis": "MIA",
    "cross_image_matching": "CIM",
    "temporal_understanding": "TU",
    "visual_code": "VP",
    "medical_understanding": "MedU",
    "autonomous_driving": "AUD",
    "discipline_knowledge_reasoning": "DKR",
    "embodied_ai": "EA",
    "gui_navigation": "GN",
}


# ============================
# Visual Processing Functions
# ============================

def mmt_doc_to_visual(doc):
    image = doc["image"].convert("RGB")
    return [image]


# ============================
# Text Processing Functions
# ============================

def mmt_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    question_text = "Question: <image>\n" + doc["question"].strip()

    # Collect only the options that are present and non-empty.
    options = []
    for option in ["A", "B", "C", "D", "E", "F", "G", "H", "I"]:
        option_text = doc.get(option)
        if option_text and option_text.strip():
            options.append(f"{option}: {option_text.strip()}")

    options_text = "\n".join(options) if options else ""

    formatted_question = f"{question_text}\n{options_text}"

    pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") if lmms_eval_specific_kwargs else ""
    post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") if lmms_eval_specific_kwargs else ""

    formatted_question = f"{pre_prompt}{formatted_question}{post_prompt}"

    return formatted_question


def mmt_doc_to_choice(doc):
    choices = []
    for option in ["A", "B", "C", "D", "E", "F", "G", "H", "I"]:
        if doc.get(option) and doc[option].strip():
            choices.append(option)
    return choices


def mmt_doc_to_target(doc):
    return doc["answer"]


# ============================
# Result Processing Functions
# ============================

def mmt_process_results(doc, results):
    response = results[0].strip()
    all_choices = [choice for choice in ["A", "B", "C", "D", "E", "F", "G", "H", "I"] if doc.get(choice)]
    pred = parse_multi_choice_response(response, all_choices)  # Adapted from MMMU

    gt_ans = doc.get("answer", "").strip()
    score = 1.0 if pred == gt_ans else 0.0

    l2_category = doc.get("l2-category", "unknown")

    accuracy_dict = {
        "overall": score,  # Overall accuracy
        l2_category: score,  # Accuracy for the specific sub-category
    }

    submission_dict = {}
    if doc.get("split") == "TEST":
        submission_dict = {doc.get("index", "unknown"): pred}  # Save using the index of the question

    return {
        "accuracy": accuracy_dict,
        "submission": submission_dict,
    }


# ============================
# Aggregation Functions
# ============================

def mmt_aggregate_results(results):
    total_correct = 0
    total_examples = 0
    category_correct = defaultdict(int)
    category_total = defaultdict(int)

    for result in results:
        # Overall accuracy
        total_correct += result["overall"]
        total_examples += 1

        # Sub-category accuracy
        for category, score in result.items():
            if category != "overall":
                category_correct[category] += score
                category_total[category] += 1

    overall_accuracy = (total_correct / total_examples) * 100 if total_examples > 0 else 0.0

    category_accuracy = {
        category: (category_correct[category] / category_total[category]) * 100 if category_total[category] > 0 else 0.0
        for category in category_correct
    }

    final_results = {
        "overall_accuracy": round(overall_accuracy, 5),
        "category_accuracy": {category: round(acc, 5) for category, acc in category_accuracy.items()},
    }

    eval_logger.info(final_results)
    return final_results


def mmt_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("mmt_test_submission.json", args)
    with open(path, "w") as f:
        json.dump(results, f)
    eval_logger.info(f"Results saved to {path}.")


##################
# Helper functions adapted from MMMU's utils.py.
##################

def parse_multi_choice_response(response, all_choices):
    """
    Parse the prediction from the generated response.
    Return the predicted choice letter, e.g., A, B, C, or D.
    """
    # Strip unwanted punctuation from the ends of the response.
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # Pad with spaces to avoid partial matches.

    candidates = []

    # Look for choices wrapped in parentheses, e.g., (A)
    for choice in all_choices:
        if f"({choice})" in response:
            candidates.append(choice)

    # Look for bare choices, e.g., A, B, C
    if len(candidates) == 0:
        for choice in all_choices:
            if f" {choice} " in response:
                candidates.append(choice)

    # Look for choices followed by a period, e.g., A., B., C.
    if len(candidates) == 0:
        for choice in all_choices:
            if f"{choice}." in response:
                candidates.append(choice)

    if len(candidates) == 0:
        # No candidate found: fall back to a random choice.
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        # More than one candidate: keep the one that appears last in the response.
        # (rfind returns -1 for candidates not matched as " X ", which argmax ignores.)
        start_indexes = [response.rfind(f" {can} ") for can in candidates]
        pred_index = candidates[np.argmax(start_indexes)]
    else:
        # Exactly one candidate: use it.
        pred_index = candidates[0]

    return pred_index
```
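A quick, hypothetical sanity check of the parser's three matching stages (parenthesized, bare, and period-suffixed letters); the inputs are made up:

```python
# Assumes parse_multi_choice_response from utils.py above is in scope.
choices = ["A", "B", "C", "D"]
assert parse_multi_choice_response("The answer is (B).", choices) == "B"    # "(B)" stage
assert parse_multi_choice_response("B", choices) == "B"                     # bare-letter stage
assert parse_multi_choice_response("I think B. fits best", choices) == "B"  # "B." stage
```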
(The diff for the ninth changed file, the `utils_mi` module referenced by the mmt_mi configs, did not render.)