From ea14cd4b361f4c95b3665cbdb95bc51754090eb5 Mon Sep 17 00:00:00 2001
From: teowu
Date: Fri, 14 Jun 2024 15:01:52 +0000
Subject: [PATCH] Add qbench, qbench2, abench; fix phi3v as its current implementation does not support multi-image

---
 lmms_eval/models/phi3v.py               |  16 +-
 lmms_eval/tasks/qbench/abench_dev.yaml  |  22 +++
 lmms_eval/tasks/qbench/qbench2_dev.yaml |  22 +++
 lmms_eval/tasks/qbench/qbench_dev.yaml  |  22 +++
 lmms_eval/tasks/qbench/qbenchs_dev.yaml |   5 +
 lmms_eval/tasks/qbench/utils.py         | 249 ++++++++++++++++++++++++
 6 files changed, 331 insertions(+), 5 deletions(-)
 create mode 100644 lmms_eval/tasks/qbench/abench_dev.yaml
 create mode 100644 lmms_eval/tasks/qbench/qbench2_dev.yaml
 create mode 100644 lmms_eval/tasks/qbench/qbench_dev.yaml
 create mode 100644 lmms_eval/tasks/qbench/qbenchs_dev.yaml
 create mode 100644 lmms_eval/tasks/qbench/utils.py

diff --git a/lmms_eval/models/phi3v.py b/lmms_eval/models/phi3v.py
index b34881ea..5e940169 100644
--- a/lmms_eval/models/phi3v.py
+++ b/lmms_eval/models/phi3v.py
@@ -185,9 +185,16 @@ def _collate(x):
             contexts = list(contexts)
             for i in range(len(contexts)):
                 if "<image>" in contexts[i]:
-                    query = contexts[i].replace("<image>", "<|image_1|>")
+                    query = "" + contexts[i]
+                    img_placeholder_count = 1
+                    while "<image>" in query:
+                        query = query.replace("<image>", f"<|image_{img_placeholder_count}|>", 1)
+                        img_placeholder_count += 1
                 else:
-                    query = f"<|image_1|>\n{contexts[i]}"
+                    query = ""
+                    for placeholder_id in range(len(visuals)):
+                        query += f"<|image_{placeholder_id+1}|>\n"
+                    query += contexts[i]
                 messages = [
                     {"role": "user", "content": query}
                 ]
@@ -196,12 +203,11 @@
                     tokenize=False,
                     add_generation_prompt=True)
             assert len(contexts) == 1
-            # We always pass a single image given that the model only accepts one image (as of 5/21/24).
+            #
             context = contexts[0]
-            pil_image = visuals[0]
             input_ids = self._processor(
                 text=context,
-                images=[pil_image],
+                images=visuals,
                 return_tensors="pt").to(self._device, self.model.dtype)
             # Setting default parameters.
if "max_new_tokens" not in gen_kwargs: diff --git a/lmms_eval/tasks/qbench/abench_dev.yaml b/lmms_eval/tasks/qbench/abench_dev.yaml new file mode 100644 index 00000000..8b44b5f7 --- /dev/null +++ b/lmms_eval/tasks/qbench/abench_dev.yaml @@ -0,0 +1,22 @@ +dataset_path: q-future/A-Bench-HF +task: "abench_dev" +test_split: dev +output_type: generate_until +doc_to_visual: !function utils.q_bench_doc_to_visual +doc_to_text: !function utils.q_bench_doc_to_text +doc_to_target: "correct_choice" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +process_results: !function utils.a_bench_process_results +metric_list: + - metric: abench_acc + aggregation: !function utils.a_bench_aggregate_results + higher_is_better: true + +model_specific_prompt_kwargs: + default: + pre_prompt: "" + post_prompt: "Answer with the option's letter from the given choices directly.\n" + \ No newline at end of file diff --git a/lmms_eval/tasks/qbench/qbench2_dev.yaml b/lmms_eval/tasks/qbench/qbench2_dev.yaml new file mode 100644 index 00000000..7412cb25 --- /dev/null +++ b/lmms_eval/tasks/qbench/qbench2_dev.yaml @@ -0,0 +1,22 @@ +dataset_path: q-future/Q-Bench2-HF +task: "qbench2_dev" +test_split: dev +output_type: generate_until +doc_to_visual: !function utils.q_bench_doc_to_visual +doc_to_text: !function utils.q_bench_doc_to_text +doc_to_target: "correct_choice" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +process_results: !function utils.q_bench_process_results +metric_list: + - metric: qbench_acc + aggregation: !function utils.q_bench_aggregate_results + higher_is_better: true + +model_specific_prompt_kwargs: + default: + pre_prompt: "" + post_prompt: "Answer with the option's letter from the given choices directly.\n" + \ No newline at end of file diff --git a/lmms_eval/tasks/qbench/qbench_dev.yaml b/lmms_eval/tasks/qbench/qbench_dev.yaml new file mode 100644 index 00000000..be901333 --- /dev/null +++ b/lmms_eval/tasks/qbench/qbench_dev.yaml @@ -0,0 +1,22 @@ +dataset_path: q-future/Q-Bench-HF +task: "qbench_dev" +test_split: dev +output_type: generate_until +doc_to_visual: !function utils.q_bench_doc_to_visual +doc_to_text: !function utils.q_bench_doc_to_text +doc_to_target: "correct_choice" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +process_results: !function utils.q_bench_process_results +metric_list: + - metric: qbench_acc + aggregation: !function utils.q_bench_aggregate_results + higher_is_better: true + +model_specific_prompt_kwargs: + default: + pre_prompt: "" + post_prompt: "Answer with the option's letter from the given choices directly.\n" + \ No newline at end of file diff --git a/lmms_eval/tasks/qbench/qbenchs_dev.yaml b/lmms_eval/tasks/qbench/qbenchs_dev.yaml new file mode 100644 index 00000000..3f7fc914 --- /dev/null +++ b/lmms_eval/tasks/qbench/qbenchs_dev.yaml @@ -0,0 +1,5 @@ +group: qbenchs_dev +task: +- qbench_dev +- qbench2_dev +- abench_dev diff --git a/lmms_eval/tasks/qbench/utils.py b/lmms_eval/tasks/qbench/utils.py new file mode 100644 index 00000000..8b69c9c4 --- /dev/null +++ b/lmms_eval/tasks/qbench/utils.py @@ -0,0 +1,247 @@ +import json +import logging +import re +from collections import Counter, defaultdict +from lmms_eval.tasks._task_utils.file_utils import generate_submission_file + + +def q_bench_doc_to_text(doc, model_specific_prompt_kwargs): + candidates = [] + for i in range(4): + candidate = doc.get(f"option{i}") + if candidate != "N/A": + candidates.append(candidate) + + question = 
doc["question"] + "\n" + "\n".join([". ".join([chr(ord("A")+i), candidate]) for i, candidate in enumerate(candidates)]) + pre_prompt = model_specific_prompt_kwargs["pre_prompt"] + post_prompt = model_specific_prompt_kwargs["post_prompt"] + return f"{pre_prompt}{question}\n{post_prompt}" + + +def q_bench_doc_to_visual(doc): + if "image2" not in doc: + return [doc["image"].convert("RGB")] + else: + return [doc["image1"].convert("RGB"), doc["image2"].convert("RGB")] + + +def get_multi_choice_info(options): + """ + Given the list of options for multiple choice question + Return the index2ans and all_choices + https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/data_utils.py#L54 + """ + + start_chr = "A" + all_choices = [] + index2ans = {} + for i, option in enumerate(options): + index2ans[chr(ord(start_chr) + i)] = option + all_choices.append(chr(ord(start_chr) + i)) + + return index2ans, all_choices + + +def parse_multi_choice_response(response, all_choices, index2ans): + """ + Parse the prediction from the generated response. + Return the predicted index e.g., A, B, C, D. + https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10 + """ + for char in [",", ".", "!", "?", ";", ":", "'"]: + response = response.strip(char) + response = " " + response + " " # add space to avoid partial match + + index_ans = True + ans_with_brack = False + candidates = [] + for choice in all_choices: # e.g., (A) (B) (C) (D) + if f"({choice})" in response: + candidates.append(choice) + ans_with_brack = True + + if len(candidates) == 0: + for choice in all_choices: # e.g., A B C D + if f"{choice} " in response: + candidates.append(choice) + + if len(candidates) == 0: + for choice in all_choices: # e.g., A. B. C. D. + if f"{choice}." in response: + candidates.append(choice) + + # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example + if len(candidates) == 0 and len(response.split()) > 5: + for index, ans in index2ans.items(): + if ans.lower() in response.lower(): + candidates.append(index) + index_ans = False # it's content ans. + + if len(candidates) == 0: # still not get answer, randomly choose one. + pred_index = random.choice(all_choices) + elif len(candidates) > 1: + start_indexes = [] + if index_ans: + if ans_with_brack: + for can in candidates: + index = response.rfind(f"({can})") + start_indexes.append(index) # -1 will be ignored anyway + # start_indexes = [generated_response.index(f'({can})') for can in candidates] + else: + for can in candidates: + index = response.rfind(f" {can} ") + start_indexes.append(index) + else: + for can in candidates: + index = response.lower().rfind(index2ans[can].lower()) + start_indexes.append(index) + # get the last one + pred_index = candidates[np.argmax(start_indexes)] + else: # if only one candidate, use it. 
+        pred_index = candidates[0]
+
+    return pred_index
+
+
+def evaluate_q_bench(samples):
+    pred_correct = 0
+    judge_dict = dict()
+    for sample in samples:
+        gold_i = sample["answer"]
+        pred_i = sample["parsed_pred"]
+        correct = eval_multi_choice(gold_i, pred_i)
+
+        if correct:
+            judge_dict[sample["id"]] = "Correct"
+            pred_correct += 1
+        else:
+            judge_dict[sample["id"]] = "Wrong"
+
+    if len(samples) == 0:
+        return {"acc": 0}
+    return judge_dict, {"acc": pred_correct / len(samples)}
+
+def eval_multi_choice(gold_i, pred_i):
+    correct = False
+    # only they are exactly the same, we consider it as correct
+    if isinstance(gold_i, list):
+        for answer in gold_i:
+            if answer == pred_i:
+                correct = True
+                break
+    else:  # gold_i is a string
+        if gold_i == pred_i:
+            correct = True
+    return correct
+
+def calculate_ins_level_acc(results):
+    """Calculate the instruction level accuracy for given Subject results
+    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L246
+    """
+    acc = 0
+    ins_num = 0
+    for cat_results in results.values():
+        acc += cat_results["acc"] * cat_results["num_example"]
+        ins_num += cat_results["num_example"]
+    if ins_num == 0:
+        return 0
+    return acc / ins_num
+
+
+def q_bench_process_results(doc, results):
+    pred = results[0]
+    all_choices = []
+    index2ans = {}
+    for i in range(4):
+        option = doc.get(f"option{i}")
+        if option == "N/A":
+            break
+        index2ans[chr(ord("A") + i)] = option
+        all_choices.append(chr(ord("A") + i))
+
+    parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
+    id = doc["id"]
+    qbench_acc = {"id": id, "question_concern": doc["question_concern"], "question_type": doc["question_type"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
+    return {
+        "qbench_acc": qbench_acc,
+        "submission": {
+            id: pred,
+        },
+    }
+
+
+concern_list = ["Global Distortion", "Global Others", "Local Distortion", "Local Others"]
+question_list = ["Yes/No", "How", "What"]
+
+def q_bench_aggregate_results(results):
+    evaluation_result = {}
+    subset_to_eval_samples = defaultdict(list)
+    for result in results:
+        subset_to_eval_samples[concern_list[result["question_concern"]]].append(result)
+        subset_to_eval_samples[question_list[result["question_type"]]].append(result)
+    for subset, sub_eval_samples in subset_to_eval_samples.items():
+        judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
+        metric_dict.update({"num_example": len(sub_eval_samples)})
+        evaluation_result[subset] = metric_dict
+    printable_results = {}
+
+    for cat_name, cat_results in evaluation_result.items():
+        printable_results[cat_name] = {
+            "num": int(cat_results["num_example"]),
+            "acc": round(cat_results["acc"], 5),
+        }
+    all_ins_acc = calculate_ins_level_acc(evaluation_result)
+    printable_results["Overall"] = {
+        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
+        "acc": round(all_ins_acc, 5),
+    }
+    print(printable_results)
+    return printable_results["Overall"]["acc"]
+
+def a_bench_process_results(doc, results):
+    pred = results[0]
+    all_choices = []
+    index2ans = {}
+    for i in range(4):
+        option = doc.get(f"option{i}")
+        if option == "N/A":
+            break
+        index2ans[chr(ord("A") + i)] = option
+        all_choices.append(chr(ord("A") + i))
+
+    parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
+    id = doc["id"]
+    abench_acc = {"id": id, "category": doc["category"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
+    return {
+        "abench_acc": abench_acc,
+        "submission": {
+            id: pred,
+        },
+    }
+
+
+
+def a_bench_aggregate_results(results):
+    evaluation_result = {}
+    subset_to_eval_samples = defaultdict(list)
+    for result in results:
+        subset_to_eval_samples[result["category"]].append(result)
+    for subset, sub_eval_samples in subset_to_eval_samples.items():
+        judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
+        metric_dict.update({"num_example": len(sub_eval_samples)})
+        evaluation_result[subset] = metric_dict
+    printable_results = {}
+
+    for cat_name, cat_results in evaluation_result.items():
+        printable_results[cat_name] = {
+            "num": int(cat_results["num_example"]),
+            "acc": round(cat_results["acc"], 5),
+        }
+    all_ins_acc = calculate_ins_level_acc(evaluation_result)
+    printable_results["Overall"] = {
+        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
+        "acc": round(all_ins_acc, 5),
+    }
+    print(printable_results)
+    return printable_results["Overall"]["acc"]
+