add stvqa and multidocvqa (EvolvingLMMs-Lab#46)
5 changed files with 192 additions and 0 deletions.
lmms_eval/tasks/multidocvqa/multidocvqa_test.yaml
@@ -0,0 +1,20 @@
dataset_path: lmms-lab/MP-DocVQA
task: "multidocvqa_test"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.multidocvqa_doc_to_visual
doc_to_text: !function utils.multidocvqa_doc_to_text
doc_to_target: "answers"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.multidocvqa_process_test_results_for_submission
metric_list:
  - metric: submission
    aggregation: !function utils.multidocvqa_test_aggregate_results_for_submission
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word or phrase."
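
As a quick sanity check, the prompt this config assembles for each question can be reproduced directly from the model_specific_prompt_kwargs above; the sample question below is a made-up stand-in for an MP-DocVQA document, not a real dataset entry:

# Minimal sketch: rebuild the prompt exactly as utils.multidocvqa_doc_to_text does,
# using the pre_prompt/post_prompt values from the YAML. The question is hypothetical.
prompt_kwargs = {"pre_prompt": "", "post_prompt": "\nAnswer the question using a single word or phrase."}
question = "What is the total shown on the invoice?"
print(f"{prompt_kwargs['pre_prompt']}{question}{prompt_kwargs['post_prompt']}")
# What is the total shown on the invoice?
# Answer the question using a single word or phrase.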
lmms_eval/tasks/multidocvqa/multidocvqa_val.yaml
@@ -0,0 +1,23 @@
dataset_path: lmms-lab/MP-DocVQA
task: "multidocvqa_val"
test_split: val
output_type: generate_until
doc_to_visual: !function utils.multidocvqa_doc_to_visual
doc_to_text: !function utils.multidocvqa_doc_to_text
doc_to_target: "answers"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.multidocvqa_process_results
metric_list:
  - metric: anls
    aggregation: !function utils.multidocvqa_aggregate_results_anls
    higher_is_better: true
  - metric: accuracy
    aggregation: !function utils.multidocvqa_aggregate_results_accuracy
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word or phrase."
lmms_eval/tasks/multidocvqa/utils.py
@@ -0,0 +1,108 @@
import os
import re
import ast
import json
from lmms_eval.api.metrics import levenshtein_distance


def multidocvqa_doc_to_text(doc, model_specific_prompt_kwargs):
    question = doc["question"]
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]

    return f"{pre_prompt}{question}{post_prompt}"


def multidocvqa_doc_to_visual(doc):
    # MP-DocVQA stores up to 20 page images per question (image_1 .. image_20);
    # missing pages are None and are skipped.
    return [doc[f'image_{i}'].convert("RGB") for i in range(1, 21) if doc[f'image_{i}'] is not None]


def multidocvqa_process_results(doc, results):
    pred_answer = results[0]
    answer = ast.literal_eval(doc['answers'])  # 'answers' is a stringified list of ground truths

    return {"anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer},
            "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}}


def multidocvqa_aggregate_results_anls(results):
    # Regroup the per-sample dicts into column lists keyed by field name.
    keys = {k for result in results for k in result}
    results = {key: [result.get(key, None) for result in results] for key in keys}
    evaluator = Evaluator(case_sensitive=False)
    metric = evaluator.get_metrics(results['answer'], results['pred_answer'])

    return sum(metric['anls']) / len(metric['anls'])


def multidocvqa_aggregate_results_accuracy(results):
    keys = {k for result in results for k in result}
    results = {key: [result.get(key, None) for result in results] for key in keys}
    evaluator = Evaluator(case_sensitive=False)
    metric = evaluator.get_metrics(results['answer'], results['pred_answer'])

    return sum(metric['accuracy']) / len(metric['accuracy'])


def multidocvqa_process_test_results_for_submission(doc, results):
    answer = results[0]
    return {"submission": {"questionId": int(doc["questionId"]), "answer": answer, "answer_page": None}}


def multidocvqa_test_aggregate_results_for_submission(results):
    os.makedirs("./submissions", exist_ok=True)
    with open("./submissions/multidocvqa_test_for_submission.json", "w") as f:
        json.dump(results, f)
    return -1  # placeholder score; the JSON submission file is the real output


##################
# Helper functions
##################


class Evaluator:
    def __init__(self, case_sensitive=False):
        self.case_sensitive = case_sensitive
        self.get_edit_distance = levenshtein_distance
        self.anls_threshold = 0.5

    def get_metrics(self, gt_answers, preds):
        batch_accuracy = []
        batch_anls = []
        for batch_idx in range(len(preds)):
            gt = [self._preprocess_str(gt_elm) for gt_elm in gt_answers[batch_idx]]
            pred = self._preprocess_str(preds[batch_idx])

            batch_accuracy.append(self._calculate_accuracy(gt, pred))
            batch_anls.append(self._calculate_anls(gt, pred))

        return {'accuracy': batch_accuracy, 'anls': batch_anls}

    def _preprocess_str(self, string):
        if not self.case_sensitive:
            string = string.lower()

        return string.strip()

    def _calculate_accuracy(self, gt, pred):
        # Exact match against any of the ground-truth answers.
        if pred == 'none':
            return 0

        for gt_elm in gt:
            if gt_elm == pred:
                return 1

        return 0

    def _calculate_anls(self, gt, pred):
        if len(pred) == 0:
            return 0

        if pred == 'none':
            return 0

        # ANLS: 1 - normalized edit distance to the closest ground truth,
        # zeroed out when the best similarity falls below the 0.5 threshold.
        answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt]
        max_similarity = max(answers_similarity)

        anls = max_similarity if max_similarity >= self.anls_threshold else 0
        return anls


if __name__ == '__main__':
    print('-----------------')
    multidocvqa_aggregate_results_anls([{"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"}, {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"}])
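
For intuition, a worked ANLS check against the Evaluator above; the import path is an assumption, and the arithmetic assumes levenshtein_distance is a standard Levenshtein implementation:

from lmms_eval.tasks.multidocvqa.utils import Evaluator  # import path is an assumption

evaluator = Evaluator(case_sensitive=False)
# 2 edits over max length 10 -> similarity 1 - 2/10 = 0.8, above the 0.5 threshold
print(evaluator._calculate_anls(["10/16/1998"], "10/16/98"))  # 0.8
# 6 edits over max length 8 -> similarity 0.25, below the threshold, so zeroed out
print(evaluator._calculate_anls(["approved"], "denied"))  # 0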
lmms_eval/tasks/stvqa/stvqa.yaml
@@ -0,0 +1,20 @@
dataset_path: lmms-lab/ST-VQA
task: "stvqa"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.stvqa_doc_to_visual
doc_to_text: !function utils.stvqa_doc_to_text
doc_to_target: "answers"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.stvqa_process_results
metric_list:
  - metric: submission
    aggregation: !function utils.stvqa_aggregate_submissions
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word or phrase."
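
Like multidocvqa_test, this task scores nothing locally; it only accumulates records for the ST-VQA evaluation server. A sketch of the payload the aggregation in utils.py below writes (the two sample answers are made up):

import json

# Hypothetical two-sample payload in the shape stvqa_aggregate_submissions dumps.
records = [{"question_id": 1, "answer": "stop"}, {"question_id": 2, "answer": "coca cola"}]
print(json.dumps(records))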
lmms_eval/tasks/stvqa/utils.py
@@ -0,0 +1,21 @@
import os
import json


def stvqa_doc_to_text(doc, model_specific_prompt_kwargs):
    question = doc["question"]
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}{post_prompt}"


def stvqa_doc_to_visual(doc):
    return [doc["image"].convert("RGB")]


def stvqa_process_results(doc, results):
    answer = results[0]
    return {"submission": {"question_id": int(doc["question_id"]), "answer": answer}}


def stvqa_aggregate_submissions(results):
    os.makedirs("./submissions", exist_ok=True)
    with open("./submissions/stvqa_test_for_submission.json", "w") as f:
        json.dump(results, f)
    return -1  # placeholder score; the submission file is the real output
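
A minimal smoke test of the submission path, assuming the module is importable as lmms_eval.tasks.stvqa.utils; the blank PIL image and the question stand in for a real ST-VQA sample:

from PIL import Image
from lmms_eval.tasks.stvqa import utils  # import path is an assumption

doc = {"question_id": 42, "question": "What does the sign say?", "image": Image.new("RGB", (64, 64))}
prompt_kwargs = {"pre_prompt": "", "post_prompt": "\nAnswer the question using a single word or phrase."}
print(utils.stvqa_doc_to_text(doc, prompt_kwargs))
record = utils.stvqa_process_results(doc, ["stop"])
utils.stvqa_aggregate_submissions([record["submission"]])  # writes ./submissions/stvqa_test_for_submission.json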