Commit

Add MMT task
ngquangtrung57 committed Sep 21, 2024
1 parent 10be8c3 commit 8928a5a
Showing 9 changed files with 494 additions and 0 deletions.
9 changes: 9 additions & 0 deletions lmms_eval/tasks/mmt/_default_template_yaml
@@ -0,0 +1,9 @@
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single character from the given options."
generation_kwargs:
  max_new_tokens: 8
metadata:
  version: 0.0
  task_type: image
4 changes: 4 additions & 0 deletions lmms_eval/tasks/mmt/mmt.yaml
@@ -0,0 +1,4 @@
group: mmt
task:
  - mmt_val
  - mmt_test
4 changes: 4 additions & 0 deletions lmms_eval/tasks/mmt/mmt_mi.yaml
@@ -0,0 +1,4 @@
group: mmt_mi
task:
  - mmt_mi_val
  - mmt_mi_test
15 changes: 15 additions & 0 deletions lmms_eval/tasks/mmt/mmt_mi_test.yaml
@@ -0,0 +1,15 @@
dataset_path: lmms-lab/MMT_MI-Benchmark
dataset_kwargs:
  token: True
task: "mmt_mi_test"
test_split: test
doc_to_visual: !function utils_mi.mmt_doc_to_visual
doc_to_text: !function utils_mi.mmt_doc_to_text
doc_to_choice: !function utils_mi.mmt_doc_to_choice
doc_to_target: !function utils_mi.mmt_doc_to_target
process_results: !function utils_mi.mmt_process_results
metric_list:
  - metric: submission
    aggregation: !function utils_mi.mmt_test_aggregate_results_for_submission
    higher_is_better: true
include: _default_template_yaml
15 changes: 15 additions & 0 deletions lmms_eval/tasks/mmt/mmt_mi_val.yaml
@@ -0,0 +1,15 @@
dataset_path: lmms-lab/MMT_MI-Benchmark
dataset_kwargs:
  token: True
task: "mmt_mi_val"
test_split: val
doc_to_visual: !function utils_mi.mmt_doc_to_visual
doc_to_text: !function utils_mi.mmt_doc_to_text
doc_to_choice: !function utils_mi.mmt_doc_to_choice
doc_to_target: !function utils_mi.mmt_doc_to_target
process_results: !function utils_mi.mmt_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils_mi.mmt_aggregate_results
    higher_is_better: true
include: _default_template_yaml
15 changes: 15 additions & 0 deletions lmms_eval/tasks/mmt/mmt_test.yaml
@@ -0,0 +1,15 @@
dataset_path: lmms-lab/MMT-Benchmark
dataset_kwargs:
  token: True
task: "mmt_test"
test_split: test
doc_to_visual: !function utils.mmt_doc_to_visual
doc_to_text: !function utils.mmt_doc_to_text
doc_to_choice: !function utils.mmt_doc_to_choice
doc_to_target: !function utils.mmt_doc_to_target
process_results: !function utils.mmt_process_results
metric_list:
  - metric: submission
    aggregation: !function utils.mmt_test_aggregate_results_for_submission
    higher_is_better: true
include: _default_template_yaml
15 changes: 15 additions & 0 deletions lmms_eval/tasks/mmt/mmt_val.yaml
@@ -0,0 +1,15 @@
dataset_path: lmms-lab/MMT-Benchmark
dataset_kwargs:
  token: True
task: "mmt_val"
test_split: val
doc_to_visual: !function utils.mmt_doc_to_visual
doc_to_text: !function utils.mmt_doc_to_text
doc_to_choice: !function utils.mmt_doc_to_choice
doc_to_target: !function utils.mmt_doc_to_target
process_results: !function utils.mmt_process_results
metric_list:
  - metric: accuracy
    aggregation: !function utils.mmt_aggregate_results
    higher_is_better: true
include: _default_template_yaml
209 changes: 209 additions & 0 deletions lmms_eval/tasks/mmt/utils.py
@@ -0,0 +1,209 @@
from PIL import Image
from collections import defaultdict
from loguru import logger as eval_logger
import random
import json
import re
import pandas as pd
import numpy as np

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

# All abbreviations for the categories
MMT_abbrs = {
    'visual_recognition': 'VR',
    'localization': 'Loc',
    'ocr': 'OCR',
    'counting': 'Count',
    'hallucination': 'HLN',
    'image_retrieval': 'IR',
    'threed': '3D',
    'visual_captioning': 'VC',
    'visual_grounding': 'VG',
    'doc_understanding': 'DU',
    'action_recognition': 'AR',
    'pixel_level_perception': 'PLP',
    'image-to-image_translation': 'I2IT',
    'relation_reasoning': 'RR',
    'intelligence_quotient_test': 'IQT',
    'emotion': 'Emo',
    'visual_illusion': 'VI',
    'meme_understanding': 'MemU',
    'visual_prompt_understanding': 'VPU',
    'anomaly_detection': 'AND',
    'keypoint_detection': 'KD',
    'visual_commonsense_reasoning': 'VCR',
    'image_evaluation_judgement': 'IEJ',
    'multiple_image_analysis': 'MIA',
    'cross_image_matching': 'CIM',
    'temporal_understanding': 'TU',
    'visual_code': 'VP',
    'medical_understanding': 'MedU',
    'autonomous_driving': 'AUD',
    'discipline_knowledge_reasoning': 'DKR',
    'embodied_ai': 'EA',
    'gui_navigation': 'GN'
}

# ============================
# Visual Processing Functions
# ============================

def mmt_doc_to_visual(doc):
    image = doc["image"].convert("RGB")
    return [image]

# ============================
# Text Processing Functions
# ============================

def mmt_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    question_text = "Question: <image>\n" + doc['question'].strip()

    options = []
    for option in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']:
        option_text = doc.get(option)
        if option_text and option_text.strip():
            options.append(f"{option}: {option_text.strip()}")

    options_text = "\n".join(options) if options else ""

    formatted_question = f"{question_text}\n{options_text}"

    pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") if lmms_eval_specific_kwargs else ""
    post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") if lmms_eval_specific_kwargs else ""

    formatted_question = f"{pre_prompt}{formatted_question}{post_prompt}"

    return formatted_question

def mmt_doc_to_choice(doc):
    choices = []
    for option in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']:
        if doc.get(option) and doc[option].strip():
            choices.append(option)
    return choices

def mmt_doc_to_target(doc):
    return doc["answer"]

# ============================
# Result Processing Functions
# ============================

def mmt_process_results(doc, results):
    response = results[0].strip()
    all_choices = [choice for choice in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'] if doc.get(choice)]
    pred = parse_multi_choice_response(response, all_choices)  # Adapted from MMMU

    gt_ans = doc.get("answer", "").strip()
    score = 1.0 if pred == gt_ans else 0.0

    l2_category = doc.get("l2-category", "unknown")

    accuracy_dict = {
        "overall": score,  # Overall accuracy
        l2_category: score  # Accuracy for the specific sub-category
    }

    submission_dict = {}
    if doc.get("split") == "TEST":
        submission_dict = {
            doc.get("index", "unknown"): pred  # Save using the index of the question
        }

    return {
        "accuracy": accuracy_dict,
        "submission": submission_dict
    }



# ============================
# Aggregation Functions
# ============================

def mmt_aggregate_results(results):
    total_correct = 0
    total_examples = 0
    category_correct = defaultdict(int)
    category_total = defaultdict(int)

    for result in results:
        # Overall accuracy
        total_correct += result['overall']
        total_examples += 1

        # Sub-category accuracy
        for category, score in result.items():
            if category != 'overall':
                category_correct[category] += score
                category_total[category] += 1

    overall_accuracy = (total_correct / total_examples) * 100 if total_examples > 0 else 0.0

    category_accuracy = {
        category: (category_correct[category] / category_total[category]) * 100 if category_total[category] > 0 else 0.0
        for category in category_correct
    }

    final_results = {
        "overall_accuracy": round(overall_accuracy, 5),
        "category_accuracy": {category: round(acc, 5) for category, acc in category_accuracy.items()}
    }

    print(final_results)
    return final_results

def mmt_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("mmt_test_submission.json", args)
    with open(path, "w") as f:
        json.dump(results, f)
    eval_logger.info(f"Results saved to {path}.")

##################
# Helper functions adapted from MMMU's utils.py.
##################

def parse_multi_choice_response(response, all_choices):
    """
    Parse the prediction from the generated response.
    Return the predicted choice letter, e.g., A, B, C, D.
    """
    # Clean response of unwanted characters
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # Add space to avoid partial match

    candidates = []

    # Look for choices with parentheses, e.g., (A)
    for choice in all_choices:
        if f"({choice})" in response:
            candidates.append(choice)

    # Look for simple choices, e.g., A, B, C
    if len(candidates) == 0:
        for choice in all_choices:
            if f" {choice} " in response:
                candidates.append(choice)

    # Look for choices with periods, e.g., A., B., C.
    if len(candidates) == 0:
        for choice in all_choices:
            if f"{choice}." in response:
                candidates.append(choice)

    # If no candidates, randomly choose one
    if len(candidates) == 0:
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        # If more than one candidate, choose the last one found
        start_indexes = [response.rfind(f" {can} ") for can in candidates]
        pred_index = candidates[np.argmax(start_indexes)]
    else:
        # If only one candidate, use it
        pred_index = candidates[0]

    return pred_index
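
As a quick illustration (a sketch, not part of the diff), here is how the helpers above behave on a hypothetical document that mirrors the fields the functions expect (question, lettered options, answer, l2-category):

    # Hypothetical example document; real entries come from the lmms-lab/MMT-Benchmark dataset.
    sample_doc = {
        "question": "What animal is shown in the image?",
        "A": "A cat",
        "B": "A dog",
        "C": "A bird",
        "answer": "B",
        "l2-category": "visual_recognition",
        "split": "VAL",
        "index": 0,
    }
    kwargs = {"pre_prompt": "", "post_prompt": "\nAnswer the question using a single character from the given options."}

    print(mmt_doc_to_text(sample_doc, kwargs))   # prompt with the <image> token and the three options appended
    print(mmt_doc_to_choice(sample_doc))         # ['A', 'B', 'C']
    print(parse_multi_choice_response("The answer is (B).", mmt_doc_to_choice(sample_doc)))  # 'B'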
[The remaining file, lmms_eval/tasks/mmt/utils_mi.py, was not expanded in this view.]
