Merge pull request #150 from ys-zong/main
Add a new benchmark: MIRB
Showing 2 changed files with 330 additions and 0 deletions.
@@ -0,0 +1,30 @@

dataset_path: VLLMs/MIRB-hf
task: mirb
dataset_kwargs:
  token: True
test_split: test
output_type: generate_until
doc_to_visual: !function utils.mirb_doc_to_visual
doc_to_text: !function utils.mirb_doc_to_text
doc_to_target: !function utils.mirb_doc_to_target
process_results: !function utils.mirb_process_results

model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: ""

generation_kwargs:
  max_new_tokens: 64
  temperature: 0
  do_sample: False

metric_list:
  - metric: mirb_score
    aggregation: !function utils.mirb_aggregation
    higher_is_better: true

metadata:
  - version: 0.0
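For reference, this config registers the task under the name mirb, so it can be selected like any other lmms-eval task from the CLI; the model name and arguments below are placeholders, not part of this change:

    python3 -m lmms_eval --model llava --model_args pretrained=liuhaotian/llava-v1.5-7b --tasks mirb --batch_size 1 --log_samples --output_path ./logs/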
@@ -0,0 +1,300 @@
import logging
import re

import numpy as np

from lmms_eval.filters.extraction import ExtendedRegexFilter
from lmms_eval.filters.transformation import MapFilter

eval_logger = logging.getLogger("lmms-eval")

def get_task_instruction(dataset):
    if dataset in ['analogy', 'attribute', 'plot_code', 'visual_chain', 'sightseeing']:
        instr = 'Answer with a single word.'
    elif dataset in ['codeu', 'food', 'image_jigsaw']:
        instr = 'Answer with the option symbol.'
    elif dataset in ['arxiv']:
        instr = 'Answer with the paper title.'
    elif dataset in ['count']:
        instr = 'Answer with a single number.'
    elif dataset in ['3d_scene']:
        instr = 'The following images are different views of the same 3D scene. Answer with a single number.'
    else:
        # fail loudly on unknown subsets instead of hitting an UnboundLocalError below
        raise ValueError(f"Unknown MIRB subset: {dataset}")

    return instr

def mirb_doc_to_text(doc, model_specific_prompt_kwargs=None):
    subset, question = doc["subset"], doc["questions"]
    task_instruction = get_task_instruction(subset)
    if model_specific_prompt_kwargs is None:
        # default to empty prompts, mirroring the defaults in the YAML config
        model_specific_prompt_kwargs = {"pre_prompt": "", "post_prompt": ""}
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    return f"{pre_prompt}{task_instruction}{question}{post_prompt}"


def mirb_doc_to_visual(doc):
    image_list = [image.convert("RGB") for image in doc["image_list"]]
    return image_list


def mirb_doc_to_target(doc):
    return doc["answers"]

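
# Example (hypothetical doc, for illustration): with the default empty pre/post prompts,
#   mirb_doc_to_text({"subset": "count", "questions": "How many cats are shown?"})
# returns "Answer with a single number.How many cats are shown?" -- the instruction and
# the question are concatenated without a separating space.
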
def extract_numbers(string):
    """
    Extract all forms of numbers from a string with regex.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L100
    """
    # Pattern for numbers with commas
    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
    # Pattern for scientific notation
    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
    # Pattern for simple numbers without commas
    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"

    # Extract numbers with commas
    numbers_with_commas = re.findall(pattern_commas, string)
    # Extract numbers in scientific notation
    numbers_scientific = re.findall(pattern_scientific, string)
    # Extract simple numbers without commas
    numbers_simple = re.findall(pattern_simple, string)

    # Combine all extracted numbers
    all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
    return all_numbers
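
# Example (for illustration): extract_numbers("There are 12 apples and 3.5 pears")
# returns ['12', '3.5']; comma-grouped and scientific-notation numbers are targeted by
# the first two patterns.
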


def check_is_number(string):
    """
    Check if the given string is a number.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L65
    """
    try:
        float(string.replace(",", ""))
        return True
    except ValueError:
        # not a number even after stripping thousands separators
        return False

def normalize_str(string):
    """
    Normalize the string to lower case and convert it to a float number if possible.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L76
    """
    string = string.strip()

    is_number = check_is_number(string)

    if is_number:
        # if it is a number, numerize it and keep two decimals
        string = string.replace(",", "")
        string = float(string)
        string = round(string, 2)
        return [string]
    else:  # it is likely to be a string
        string = string.lower()
        if len(string) == 1:
            return [" " + string, string + " "]  # avoid trivial matches
        return [string]
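
# Examples (for illustration): normalize_str("1,234") returns [1234.0] (rounded to two
# decimals), while normalize_str("A") returns [" a", "a "] so that single letters only
# match when surrounded by whitespace.
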


def parse_multi_choice_response(response):
    # Extract a leading option letter such as "B." from the response;
    # if none is found, fall back to the raw response.
    option_letter_regex = re.compile(r"^\s*([A-Z])\.")
    match = option_letter_regex.match(response)
    if match:
        # a match was found: return the option letter only
        return match.group(1)
    # no match: return the original response
    return response
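
# Example (for illustration): parse_multi_choice_response("B. The Eiffel Tower") returns "B",
# while a response without a leading option letter is returned unchanged.
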


def parse_open_response(response):
    """
    Parse the prediction from the generated response.
    Return a list of predicted strings or numbers.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L122
    """

    def get_key_subresponses(response):
        response = response.strip().strip(".").lower()
        sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
        indicators_of_keys = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        key_responses = []
        for index, resp in enumerate(sub_responses):
            # if this is the last sub-response, also accept an equation
            # (the entire response can be a single sentence with an equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(["="])
            shortest_key_response = None  # the shortest tail that may contain the answer
            for indicator in indicators_of_keys:
                if indicator in resp:
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    elif len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
                        shortest_key_response = resp.split(indicator)[-1].strip()

            # keep it only if it is not trivial punctuation
            if shortest_key_response and shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", "'"]:
                key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not find any indicator
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)

    pred_list = key_responses.copy()  # keep the original string responses
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))

    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list

    # remove duplicates
    pred_list = list(set(pred_list))

    return pred_list
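
# Example (for illustration): parse_open_response("The answer is 42.") isolates the tail
# "42" via the "is " indicator and normalizes it, returning [42.0], which eval_open()
# below can match against a normalized gold answer.
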


def mirb_process_results(doc, results):
    pred = results[0]
    subset, answer = doc["subset"], doc["answers"]
    if answer in ['A', 'B', 'C', 'D', 'E']:  # MCQ tasks
        parsed_pred = parse_multi_choice_response(pred)
    else:
        parsed_pred = parse_open_response(pred)
    data_dict = {"question_id": doc["question_id"], "subset": subset, "pred_answer": parsed_pred, "answers": answer}
    return {"mirb_score": data_dict}


def eval_multi_choice(gold_i, pred_i):
    """
    Evaluate a multiple-choice instance.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L175
    """
    correct = False
    # only an exact match with (one of) the gold answer(s) counts as correct
    if isinstance(gold_i, list):
        for answer in gold_i:
            if answer == pred_i:
                correct = True
                break
    else:  # gold_i is a string
        if gold_i == pred_i:
            correct = True
    return correct


def eval_open(gold_i, pred_i):
    """
    Evaluate an open-ended question instance.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L191
    """
    correct = False
    if isinstance(gold_i, list):
        # normalize every gold answer; floats avoid trivial substring matches
        norm_answers = []
        for answer in gold_i:
            norm_answers.extend(normalize_str(answer))
    else:
        norm_answers = normalize_str(gold_i)
    for pred in pred_i:  # predictions were already normalized in the parsing phase
        if isinstance(pred, str):  # string prediction: check whether a gold answer is a substring of it
            for norm_ans in norm_answers:
                if isinstance(norm_ans, str) and norm_ans in pred:
                    if not correct:
                        correct = True
                    break
        else:  # float prediction: check exact membership
            if pred in norm_answers:
                if not correct:
                    correct = True
                break
    return correct
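
# Examples (for illustration): eval_open("Paris", ["the capital is paris"]) returns True via
# the substring check, and eval_open("12", [12.0]) returns True because both sides normalize
# to the float 12.0.
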


def mirb_aggregation(results):
    task_num = {}
    score = 0
    task_score = {}
    for result in results:
        if result["subset"] not in task_score:
            task_score[result["subset"]] = 0
            task_num[result["subset"]] = 0

        if result["answers"] in ['A', 'B', 'C', 'D', 'E']:  # MCQ tasks
            correct = eval_multi_choice(result["answers"], result["pred_answer"])
        else:
            correct = eval_open(result["answers"], result["pred_answer"])
        task_score[result["subset"]] += correct
        score += correct
        task_num[result["subset"]] += 1
    avg_score = score / len(results)

    task_score = {k: v / task_num[k] for k, v in task_score.items()}

    eval_logger.info("Performances for different subsets:")
    eval_logger.info("=" * 50)
    for k, v in task_score.items():
        eval_logger.info(f"{k} : {v:.2f}")
    eval_logger.info("=" * 50)

    # report scores grouped by evaluation dimension
    groups = {
        "Knowledge": ["food", "sightseeing"],
        "Reasoning": ["codeu", "plot_code", "analogy", "3d_scene"],
        "Perception": ["image_jigsaw", "count", "attribute"],
        "Multi-Hop": ["visual_chain", "arxiv"],
    }

    # compute the average score for each group
    averages_dict = compute_averages_from_task_scores(task_score, groups)

    eval_logger.info("Performances for different dimensions:")
    eval_logger.info("=" * 50)
    for group, group_score in averages_dict.items():
        eval_logger.info(f"{group} : {group_score:.2f}")
    eval_logger.info("=" * 50)

    return avg_score


# Compute averages for each group based on the per-task scores
def compute_averages_from_task_scores(task_score, groups):
    averages = {}
    for group, features in groups.items():
        values = [task_score[feature] for feature in features]
        averages[group] = np.mean(values)
    return averages
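
# Example (hypothetical scores, for illustration): given a complete task_score dict where
# task_score["food"] == 0.5 and task_score["sightseeing"] == 0.7, the "Knowledge" entry of
# the returned dict is their mean, 0.6.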