From 1cbc746a2a32b975559067f1321a16dbda54c8be Mon Sep 17 00:00:00 2001 From: XinrunDu <154438029+XinrunDu@users.noreply.github.com> Date: Tue, 13 Feb 2024 18:50:23 +0800 Subject: [PATCH] add stvqa and multidocvqa (#46) --- .../tasks/multidocvqa/multidocvqa_test.yaml | 20 ++++ .../tasks/multidocvqa/multidocvqa_val.yaml | 23 ++++ lmms_eval/tasks/multidocvqa/utils.py | 108 ++++++++++++++++++ lmms_eval/tasks/stvqa/stvqa.yaml | 20 ++++ lmms_eval/tasks/stvqa/utils.py | 21 ++++ 5 files changed, 192 insertions(+) create mode 100644 lmms_eval/tasks/multidocvqa/multidocvqa_test.yaml create mode 100644 lmms_eval/tasks/multidocvqa/multidocvqa_val.yaml create mode 100644 lmms_eval/tasks/multidocvqa/utils.py create mode 100644 lmms_eval/tasks/stvqa/stvqa.yaml create mode 100644 lmms_eval/tasks/stvqa/utils.py diff --git a/lmms_eval/tasks/multidocvqa/multidocvqa_test.yaml b/lmms_eval/tasks/multidocvqa/multidocvqa_test.yaml new file mode 100644 index 00000000..d33c5a50 --- /dev/null +++ b/lmms_eval/tasks/multidocvqa/multidocvqa_test.yaml @@ -0,0 +1,20 @@ +dataset_path: lmms-lab/MP-DocVQA +task: "multidocvqa_test" +test_split: test +output_type: generate_until +doc_to_visual: !function utils.multidocvqa_doc_to_visual +doc_to_text: !function utils.multidocvqa_doc_to_text +doc_to_target: "answers" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +process_results: !function utils.multidocvqa_process_test_results_for_submission +metric_list: + - metric: submission + aggregation: !function utils.multidocvqa_test_aggregate_results_for_submission +model_specific_prompt_kwargs: + default: + pre_prompt: "" + post_prompt: "\nAnswer the question using a single word or phrase." + \ No newline at end of file diff --git a/lmms_eval/tasks/multidocvqa/multidocvqa_val.yaml b/lmms_eval/tasks/multidocvqa/multidocvqa_val.yaml new file mode 100644 index 00000000..b4f5238b --- /dev/null +++ b/lmms_eval/tasks/multidocvqa/multidocvqa_val.yaml @@ -0,0 +1,23 @@ +dataset_path: lmms-lab/MP-DocVQA +task: "multidocvqa_val" +test_split: val +output_type: generate_until +doc_to_visual: !function utils.multidocvqa_doc_to_visual +doc_to_text: !function utils.multidocvqa_doc_to_text +doc_to_target: "answers" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +process_results: !function utils.multidocvqa_process_results +metric_list: + - metric: anls + aggregation: !function utils.multidocvqa_aggregate_results_anls + higher_is_better: true + - metric: accuracy + aggregation: !function utils.multidocvqa_aggregate_results_accuracy + higher_is_better: true +model_specific_prompt_kwargs: + default: + pre_prompt: "" + post_prompt: "\nAnswer the question using a single word or phrase." diff --git a/lmms_eval/tasks/multidocvqa/utils.py b/lmms_eval/tasks/multidocvqa/utils.py new file mode 100644 index 00000000..33241cf3 --- /dev/null +++ b/lmms_eval/tasks/multidocvqa/utils.py @@ -0,0 +1,108 @@ +import os +import re +import ast +import json +from lmms_eval.api.metrics import levenshtein_distance + + +def multidocvqa_doc_to_text(doc, model_specific_prompt_kwargs): + question = doc["question"] + pre_prompt = model_specific_prompt_kwargs["pre_prompt"] + post_prompt = model_specific_prompt_kwargs["post_prompt"] + + return f"{pre_prompt}{question}{post_prompt}" + +def multidocvqa_doc_to_visual(doc): + return [doc[f'image_{i}'].convert("RGB") for i in range(1, 21) if doc[f'image_{i}'] is not None] + +def multidocvqa_process_results(doc, results): + pred_answer = results[0] + answer = ast.literal_eval(doc['answers']) + + return {"anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}, + "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}} + +def multidocvqa_aggregate_results_anls(results): + keys = {k for result in results for k in result} + results = {key: [result.get(key, None) for result in results] for key in keys} + evaluator = Evaluator(case_sensitive=False) + metric = evaluator.get_metrics(results['answer'], results['pred_answer']) + + return sum(metric['anls']) / len(metric['anls']) + +def multidocvqa_aggregate_results_accuracy(results): + keys = {k for result in results for k in result} + results = {key: [result.get(key, None) for result in results] for key in keys} + evaluator = Evaluator(case_sensitive=False) + metric = evaluator.get_metrics(results['answer'], results['pred_answer']) + + return sum(metric['accuracy']) / len(metric['accuracy']) + +def multidocvqa_process_test_results_for_submission(doc, results): + answer = results[0] + return {"submission": {"questionId": int(doc["questionId"]), "answer": answer, "answer_page": None}} + +def multidocvqa_test_aggregate_results_for_submission(results): + os.makedirs("./submissions", exist_ok=True) + with open("./submissions/multidocvqa_test_for_submission.json", "w") as f: + json.dump(results, f) + return -1 + + +################## +# Helper functions +################## + + +class Evaluator: + def __init__(self, case_sensitive=False): + + self.case_sensitive = case_sensitive + self.get_edit_distance = levenshtein_distance + self.anls_threshold = 0.5 + + def get_metrics(self, gt_answers, preds): + batch_accuracy = [] + batch_anls = [] + for batch_idx in range(len(preds)): + gt = [self._preprocess_str(gt_elm) for gt_elm in gt_answers[batch_idx]] + pred = self._preprocess_str(preds[batch_idx]) + + batch_accuracy.append(self._calculate_accuracy(gt, pred)) + batch_anls.append(self._calculate_anls(gt, pred)) + + return {'accuracy': batch_accuracy, 'anls': batch_anls} + + def _preprocess_str(self, string): + if not self.case_sensitive: + string = string.lower() + + return string.strip() + + def _calculate_accuracy(self, gt, pred): + + if pred == 'none': + return 0 + + for gt_elm in gt: + if gt_elm == pred: + return 1 + + return 0 + + def _calculate_anls(self, gt, pred): + if len(pred) == 0: + return 0 + + if pred == 'none': + return 0 + + answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt] + max_similarity = max(answers_similarity) + + anls = max_similarity if max_similarity >= self.anls_threshold else 0 + return anls + +if __name__ == '__main__': + print('-----------------') + multidocvqa_aggregate_results_anls([{"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"}, {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"}]) \ No newline at end of file diff --git a/lmms_eval/tasks/stvqa/stvqa.yaml b/lmms_eval/tasks/stvqa/stvqa.yaml new file mode 100644 index 00000000..42ecaaca --- /dev/null +++ b/lmms_eval/tasks/stvqa/stvqa.yaml @@ -0,0 +1,20 @@ +dataset_path: lmms-lab/ST-VQA +task: "stvqa" +test_split: test +output_type: generate_until +doc_to_visual: !function utils.stvqa_doc_to_visual +doc_to_text: !function utils.stvqa_doc_to_text +doc_to_target: "answers" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +process_results: !function utils.stvqa_process_results +metric_list: + - metric: submission + aggregation: !function utils.stvqa_aggregate_submissions +model_specific_prompt_kwargs: + default: + pre_prompt: "" + post_prompt: "\nAnswer the question using a single word or phrase." + \ No newline at end of file diff --git a/lmms_eval/tasks/stvqa/utils.py b/lmms_eval/tasks/stvqa/utils.py new file mode 100644 index 00000000..1f106f06 --- /dev/null +++ b/lmms_eval/tasks/stvqa/utils.py @@ -0,0 +1,21 @@ +import os +import json + +def stvqa_doc_to_text(doc, model_specific_prompt_kwargs): + question = doc["question"] + pre_prompt = model_specific_prompt_kwargs["pre_prompt"] + post_prompt = model_specific_prompt_kwargs["post_prompt"] + return f"{pre_prompt}{question}{post_prompt}" + +def stvqa_doc_to_visual(doc): + return [doc["image"].convert("RGB")] + +def stvqa_process_results(doc, results): + answer = results[0] + return {"submission": {"question_id": int(doc["question_id"]), "answer": answer}} + +def stvqa_aggregate_submissions(results): + os.makedirs("./submissions", exist_ok=True) + with open("./submissions/stvqa_test_for_submission.json", "w") as f: + json.dump(results, f) + return -1 \ No newline at end of file