
Commit

add stvqa and multidocvqa (EvolvingLMMs-Lab#46)
XinrunDu authored Feb 13, 2024
1 parent b5204d4 commit c6d4d44
Showing 5 changed files with 192 additions and 0 deletions.
20 changes: 20 additions & 0 deletions lmms_eval/tasks/multidocvqa/multidocvqa_test.yaml
@@ -0,0 +1,20 @@
dataset_path: lmms-lab/MP-DocVQA
task: "multidocvqa_test"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.multidocvqa_doc_to_visual
doc_to_text: !function utils.multidocvqa_doc_to_text
doc_to_target: "answers"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.multidocvqa_process_test_results_for_submission
metric_list:
  - metric: submission
    aggregation: !function utils.multidocvqa_test_aggregate_results_for_submission
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word or phrase."
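
(The test split has no public answers, so this config routes predictions into a submission file rather than scoring locally; the val config below computes ANLS and accuracy directly.)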

23 changes: 23 additions & 0 deletions lmms_eval/tasks/multidocvqa/multidocvqa_val.yaml
@@ -0,0 +1,23 @@
dataset_path: lmms-lab/MP-DocVQA
task: "multidocvqa_val"
test_split: val
output_type: generate_until
doc_to_visual: !function utils.multidocvqa_doc_to_visual
doc_to_text: !function utils.multidocvqa_doc_to_text
doc_to_target: "answers"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.multidocvqa_process_results
metric_list:
  - metric: anls
    aggregation: !function utils.multidocvqa_aggregate_results_anls
    higher_is_better: true
  - metric: accuracy
    aggregation: !function utils.multidocvqa_aggregate_results_accuracy
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word or phrase."
108 changes: 108 additions & 0 deletions lmms_eval/tasks/multidocvqa/utils.py
@@ -0,0 +1,108 @@
import ast
import json
import os

from lmms_eval.api.metrics import levenshtein_distance


def multidocvqa_doc_to_text(doc, model_specific_prompt_kwargs):
    question = doc["question"]
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}{post_prompt}"
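
# Illustrative only (hypothetical doc; the prompt kwargs come from the task YAMLs above):
# multidocvqa_doc_to_text({"question": "What is the total?"},
#                         {"pre_prompt": "", "post_prompt": "\nAnswer the question using a single word or phrase."})
# -> "What is the total?\nAnswer the question using a single word or phrase."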

def multidocvqa_doc_to_visual(doc):
    # MP-DocVQA stores up to 20 page images per question (image_1 .. image_20); keep the non-empty ones.
    return [doc[f"image_{i}"].convert("RGB") for i in range(1, 21) if doc[f"image_{i}"] is not None]

def multidocvqa_process_results(doc, results):
    pred_answer = results[0]
    answer = ast.literal_eval(doc["answers"])  # "answers" is stored as a stringified list

    return {
        "anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer},
        "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer},
    }

def multidocvqa_aggregate_results_anls(results):
    # Transpose the list of per-doc dicts into a dict of parallel lists.
    keys = {k for result in results for k in result}
    results = {key: [result.get(key, None) for result in results] for key in keys}
    evaluator = Evaluator(case_sensitive=False)
    metric = evaluator.get_metrics(results["answer"], results["pred_answer"])
    return sum(metric["anls"]) / len(metric["anls"])
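
# Illustrative transpose (made-up values): [{"questionId": 1, "answer": ["a"], "pred_answer": "a"}]
# becomes {"questionId": [1], "answer": [["a"]], "pred_answer": ["a"]}, so results["answer"] and
# results["pred_answer"] stay aligned index-by-index for Evaluator.get_metrics.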

def multidocvqa_aggregate_results_accuracy(results):
    keys = {k for result in results for k in result}
    results = {key: [result.get(key, None) for result in results] for key in keys}
    evaluator = Evaluator(case_sensitive=False)
    metric = evaluator.get_metrics(results["answer"], results["pred_answer"])
    return sum(metric["accuracy"]) / len(metric["accuracy"])

def multidocvqa_process_test_results_for_submission(doc, results):
    answer = results[0]
    # answer_page is part of the MP-DocVQA submission format; this task does not predict it.
    return {"submission": {"questionId": int(doc["questionId"]), "answer": answer, "answer_page": None}}


def multidocvqa_test_aggregate_results_for_submission(results):
    os.makedirs("./submissions", exist_ok=True)
    with open("./submissions/multidocvqa_test_for_submission.json", "w") as f:
        json.dump(results, f)
    # No local score for the hidden test split; return a sentinel value.
    return -1


##################
# Helper functions
##################


class Evaluator:
    def __init__(self, case_sensitive=False):
        self.case_sensitive = case_sensitive
        self.get_edit_distance = levenshtein_distance
        self.anls_threshold = 0.5

    def get_metrics(self, gt_answers, preds):
        batch_accuracy = []
        batch_anls = []
        for batch_idx in range(len(preds)):
            gt = [self._preprocess_str(gt_elm) for gt_elm in gt_answers[batch_idx]]
            pred = self._preprocess_str(preds[batch_idx])

            batch_accuracy.append(self._calculate_accuracy(gt, pred))
            batch_anls.append(self._calculate_anls(gt, pred))

        return {"accuracy": batch_accuracy, "anls": batch_anls}

    def _preprocess_str(self, string):
        if not self.case_sensitive:
            string = string.lower()
        return string.strip()

    def _calculate_accuracy(self, gt, pred):
        # Exact match against any ground-truth answer; "none" predictions never count.
        if pred == "none":
            return 0
        for gt_elm in gt:
            if gt_elm == pred:
                return 1
        return 0

    def _calculate_anls(self, gt, pred):
        if len(pred) == 0 or pred == "none":
            return 0

        # Normalized Levenshtein similarity against each ground-truth answer;
        # the best score is kept, truncated to 0 below the ANLS threshold.
        answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt]
        max_similarity = max(answers_similarity)
        return max_similarity if max_similarity >= self.anls_threshold else 0

if __name__ == "__main__":
    # Minimal smoke test for the ANLS aggregation.
    print("-----------------")
    multidocvqa_aggregate_results_anls(
        [
            {"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"},
            {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"},
        ]
    )
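
A quick sanity check of the ANLS scoring above; a minimal sketch, with illustrative strings and the expected values worked out from the formula in _calculate_anls:

evaluator = Evaluator(case_sensitive=False)
# "Hello" lowercases to "hello": edit distance 0 -> similarity 1.0 -> ANLS 1.0
# "helo" vs "hello": edit distance 1, max length 5 -> similarity 0.8 >= 0.5 -> ANLS 0.8
# "world" vs "hello": edit distance 4, max length 5 -> similarity 0.2 < 0.5 -> ANLS 0.0
metrics = evaluator.get_metrics([["hello"], ["hello"], ["hello"]], ["Hello", "helo", "world"])
print(metrics["anls"])  # expected: [1.0, 0.8, 0.0]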
20 changes: 20 additions & 0 deletions lmms_eval/tasks/stvqa/stvqa.yaml
@@ -0,0 +1,20 @@
dataset_path: lmms-lab/ST-VQA
task: "stvqa"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.stvqa_doc_to_visual
doc_to_text: !function utils.stvqa_doc_to_text
doc_to_target: "answers"
generation_kwargs:
  max_new_tokens: 32
  temperature: 0
  do_sample: False
process_results: !function utils.stvqa_process_results
metric_list:
  - metric: submission
    aggregation: !function utils.stvqa_aggregate_submissions
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word or phrase."

21 changes: 21 additions & 0 deletions lmms_eval/tasks/stvqa/utils.py
@@ -0,0 +1,21 @@
import os
import json

def stvqa_doc_to_text(doc, model_specific_prompt_kwargs):
    question = doc["question"]
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}{post_prompt}"


def stvqa_doc_to_visual(doc):
    return [doc["image"].convert("RGB")]


def stvqa_process_results(doc, results):
    answer = results[0]
    return {"submission": {"question_id": int(doc["question_id"]), "answer": answer}}


def stvqa_aggregate_submissions(results):
    os.makedirs("./submissions", exist_ok=True)
    with open("./submissions/stvqa_test_for_submission.json", "w") as f:
        json.dump(results, f)
    # Submission-only task: nothing is scored locally.
    return -1
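
For reference, a minimal sketch of how the submission file gets produced, assuming lmms-eval hands the aggregation the list of per-doc "submission" dicts emitted by stvqa_process_results (the ids and answers are made up):

stvqa_aggregate_submissions(
    [
        {"question_id": 1, "answer": "stop"},
        {"question_id": 2, "answer": "coca-cola"},
    ]
)  # writes ./submissions/stvqa_test_for_submission.json and returns -1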
