Commit

Add MMT
ngquangtrung57 committed Sep 21, 2024
1 parent 8928a5a commit 100de1a
Showing 2 changed files with 86 additions and 63 deletions.
77 changes: 44 additions & 33 deletions lmms_eval/tasks/mmt/utils.py
@@ -1,15 +1,13 @@
from PIL import Image
from collections import defaultdict
from loguru import logger as eval_logger
import random
import json
import re
import pandas as pd
import numpy as np

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

# All abbreviations for the categories
'''
MMT_abbrs = {
    'visual_recognition': 'VR',
    'localization': 'Loc',
@@ -44,19 +42,20 @@
    'embodied_ai': 'EA',
    'gui_navigation': 'GN'
}

'''
# ============================
# Visual Processing Functions
# ============================

-def mmt_doc_to_visual(doc):
-    image = doc["image"].convert("RGB")
-    return [image]
+def mmt_doc_to_visual(doc):
+    return [image.convert("RGB") for image in doc["image"]]

# ============================
# Text Processing Functions
# ============================


def mmt_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    question_text = "Question: <image>\n" + doc['question'].strip()

@@ -70,56 +69,62 @@ def mmt_doc_to_text(doc, lmms_eval_specific_kwargs=None):

    formatted_question = f"{question_text}\n{options_text}"

-    pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") if lmms_eval_specific_kwargs else ""
-    post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") if lmms_eval_specific_kwargs else ""
+    pre_prompt = (
+        lmms_eval_specific_kwargs.get("pre_prompt", "")
+        if lmms_eval_specific_kwargs else ""
+    )
+
+    post_prompt = (
+        lmms_eval_specific_kwargs.get("post_prompt", "")
+        if lmms_eval_specific_kwargs else ""
+    )
    formatted_question = f"{pre_prompt}{formatted_question}{post_prompt}"

    return formatted_question


def mmt_doc_to_choice(doc):
    choices = []
    for option in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']:
        if doc.get(option) and doc[option].strip():
            choices.append(option)
    return choices


def mmt_doc_to_target(doc):
    return doc["answer"]

# ============================
# Result Processing Functions
# ============================


def mmt_process_results(doc, results):

-    response = results[0].strip()
-    all_choices = [choice for choice in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'] if doc.get(choice)]
-    pred = parse_multi_choice_response(response, all_choices)  #Adapt from MMMU
+    response = results[0].strip()
+    all_choices = [
+        choice for choice in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
+        if doc.get(choice)
+    ]
+    pred = parse_multi_choice_response(response, all_choices)  # Adapt from MMMU
    gt_ans = doc.get("answer", "").strip()
    score = 1.0 if pred == gt_ans else 0.0

    l2_category = doc.get("l2-category", "unknown")

    accuracy_dict = {
        "overall": score,  # Overall accuracy
        l2_category: score  # Accuracy for the specific sub-category
    }

    submission_dict = {}
    if doc.get("split") == "TEST":
        submission_dict = {
-            doc.get("index", "unknown"): pred  # Save using index of the question
+            doc.get("index", "unknown"): pred  # Save using index
        }

    return {
        "accuracy": accuracy_dict,
        "submission": submission_dict
    }



# ============================
# Aggregation Functions
# ============================
@@ -129,33 +134,39 @@ def mmt_aggregate_results(results):
    total_examples = 0
    category_correct = defaultdict(int)
    category_total = defaultdict(int)

    for result in results:
        # Overall accuracy
        total_correct += result['overall']
        total_examples += 1

        # Sub-category accuracy
        for category, score in result.items():
            if category != 'overall':
                category_correct[category] += score
                category_total[category] += 1

-    overall_accuracy = (total_correct / total_examples) * 100 if total_examples > 0 else 0.0
+    overall_accuracy = (
+        (total_correct / total_examples) * 100
+        if total_examples > 0
+        else 0.0
+    )
    category_accuracy = {
-        category: (category_correct[category] / category_total[category]) * 100 if category_total[category] > 0 else 0.0
+        category: (
+            (category_correct[category] / category_total[category]) * 100
+            if category_total[category] > 0
+            else 0.0
+        )
        for category in category_correct
    }

    final_results = {
        "overall_accuracy": round(overall_accuracy, 5),
-        "category_accuracy": {category: round(acc, 5) for category, acc in category_accuracy.items()}
+        "category_accuracy": {
+            category: round(acc, 5)
+            for category, acc in category_accuracy.items()
+        }
    }

    print(final_results)
    return final_results


def mmt_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("mmt_test_submission.json", args)
    with open(path, "w") as f:
@@ -166,6 +177,7 @@ def mmt_test_aggregate_results_for_submission(results, args):
# Helper functions adapted from MMMU's utils.py.
##################


def parse_multi_choice_response(response, all_choices):
"""
Parse the prediction from the generated response.
@@ -177,7 +189,6 @@ def parse_multi_choice_response(response, all_choices):
    response = " " + response + " "  # Add space to avoid partial match

    candidates = []

    # Look for choices with parentheses, e.g., (A)
    for choice in all_choices:
        if f"({choice})" in response:
@@ -206,4 +217,4 @@
        # If only one candidate, use it
        pred_index = candidates[0]

    return pred_index
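
For context, the scoring path above can be exercised outside the harness. A minimal sketch, not part of the commit: the doc values below are hypothetical, the keys ("question", "A" through "I", "answer", "l2-category", "split", "index") mirror those used in utils.py, and it assumes an lmms-eval checkout where the module is importable.

# A minimal sanity check of the scoring path (hypothetical doc values).
from lmms_eval.tasks.mmt.utils import mmt_process_results

doc = {
    "question": "What is shown in the image?",
    "A": "A cat", "B": "A dog", "C": "A bird", "D": "A fish",
    "answer": "B",
    "l2-category": "visual_recognition",
    "split": "VAL",
    "index": 42,
}

# parse_multi_choice_response picks "B" out of the raw generation.
result = mmt_process_results(doc, ["The answer is (B)."])
print(result["accuracy"])    # {'overall': 1.0, 'visual_recognition': 1.0}
print(result["submission"])  # {} -- only populated when split == "TEST"
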
72 changes: 42 additions & 30 deletions lmms_eval/tasks/mmt/utils_mi.py
@@ -1,15 +1,13 @@
from PIL import Image
from collections import defaultdict
from loguru import logger as eval_logger
import random
import json
import re
import pandas as pd
import numpy as np

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

# All abbreviations for the categories
'''
MMT_abbrs = {
    'visual_recognition': 'VR',
    'localization': 'Loc',
@@ -44,18 +42,20 @@
    'embodied_ai': 'EA',
    'gui_navigation': 'GN'
}

'''
# ============================
# Visual Processing Functions
# ============================


def mmt_doc_to_visual(doc):
    return [image.convert("RGB") for image in doc["image"]]

# ============================
# Text Processing Functions
# ============================


def mmt_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    question_text = "Question: <image>\n" + doc['question'].strip()

@@ -69,56 +69,62 @@ def mmt_doc_to_text(doc, lmms_eval_specific_kwargs=None):

    formatted_question = f"{question_text}\n{options_text}"

-    pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") if lmms_eval_specific_kwargs else ""
-    post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") if lmms_eval_specific_kwargs else ""
+    pre_prompt = (
+        lmms_eval_specific_kwargs.get("pre_prompt", "")
+        if lmms_eval_specific_kwargs else ""
+    )
+
+    post_prompt = (
+        lmms_eval_specific_kwargs.get("post_prompt", "")
+        if lmms_eval_specific_kwargs else ""
+    )
    formatted_question = f"{pre_prompt}{formatted_question}{post_prompt}"

    return formatted_question


def mmt_doc_to_choice(doc):
    choices = []
    for option in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']:
        if doc.get(option) and doc[option].strip():
            choices.append(option)
    return choices


def mmt_doc_to_target(doc):
    return doc["answer"]

# ============================
# Result Processing Functions
# ============================


def mmt_process_results(doc, results):

-    response = results[0].strip()
-    all_choices = [choice for choice in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'] if doc.get(choice)]
-    pred = parse_multi_choice_response(response, all_choices)  #Adapt from MMMU
+    response = results[0].strip()
+    all_choices = [
+        choice for choice in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
+        if doc.get(choice)
+    ]
+    pred = parse_multi_choice_response(response, all_choices)  # Adapt from MMMU
    gt_ans = doc.get("answer", "").strip()
    score = 1.0 if pred == gt_ans else 0.0

    l2_category = doc.get("l2-category", "unknown")

    accuracy_dict = {
        "overall": score,  # Overall accuracy
        l2_category: score  # Accuracy for the specific sub-category
    }

    submission_dict = {}
    if doc.get("split") == "TEST":
        submission_dict = {
-            doc.get("index", "unknown"): pred  # Save using index of the question
+            doc.get("index", "unknown"): pred  # Save using index
        }

    return {
        "accuracy": accuracy_dict,
        "submission": submission_dict
    }



# ============================
# Aggregation Functions
# ============================
@@ -128,33 +134,39 @@ def mmt_aggregate_results(results):
    total_examples = 0
    category_correct = defaultdict(int)
    category_total = defaultdict(int)

    for result in results:
        # Overall accuracy
        total_correct += result['overall']
        total_examples += 1

        # Sub-category accuracy
        for category, score in result.items():
            if category != 'overall':
                category_correct[category] += score
                category_total[category] += 1

-    overall_accuracy = (total_correct / total_examples) * 100 if total_examples > 0 else 0.0
+    overall_accuracy = (
+        (total_correct / total_examples) * 100
+        if total_examples > 0
+        else 0.0
+    )
    category_accuracy = {
-        category: (category_correct[category] / category_total[category]) * 100 if category_total[category] > 0 else 0.0
+        category: (
+            (category_correct[category] / category_total[category]) * 100
+            if category_total[category] > 0
+            else 0.0
+        )
        for category in category_correct
    }

    final_results = {
        "overall_accuracy": round(overall_accuracy, 5),
-        "category_accuracy": {category: round(acc, 5) for category, acc in category_accuracy.items()}
+        "category_accuracy": {
+            category: round(acc, 5)
+            for category, acc in category_accuracy.items()
+        }
    }

    print(final_results)
    return final_results


def mmt_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("mmt_test_submission.json", args)
    with open(path, "w") as f:
@@ -165,6 +177,7 @@ def mmt_test_aggregate_results_for_submission(results, args):
# Helper functions adapted from MMMU's utils.py.
##################


def parse_multi_choice_response(response, all_choices):
"""
Parse the prediction from the generated response.
@@ -176,7 +189,6 @@ def parse_multi_choice_response(response, all_choices):
    response = " " + response + " "  # Add space to avoid partial match

    candidates = []

    # Look for choices with parentheses, e.g., (A)
    for choice in all_choices:
        if f"({choice})" in response:
@@ -205,4 +217,4 @@
        # If only one candidate, use it
        pred_index = candidates[0]

    return pred_index

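The aggregation step can be checked the same way. A minimal sketch, not part of the commit: the per-example results are toy values shaped like the accuracy dicts that mmt_process_results emits, and the category names are illustrative.

# A minimal sanity check of the aggregation step (toy results).
from lmms_eval.tasks.mmt.utils_mi import mmt_aggregate_results

results = [
    {"overall": 1.0, "visual_recognition": 1.0},
    {"overall": 0.0, "visual_recognition": 0.0},
    {"overall": 1.0, "ocr": 1.0},
]

final = mmt_aggregate_results(results)
# final == {'overall_accuracy': 66.66667,
#           'category_accuracy': {'visual_recognition': 50.0, 'ocr': 100.0}}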