diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py
index a6285bcd..8fd1b4ae 100644
--- a/lmms_eval/__main__.py
+++ b/lmms_eval/__main__.py
@@ -172,17 +172,21 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
is_main_process = False
for args in args_list:
- if is_main_process:
- wandb_logger = WandbLogger(args)
- results = cli_evaluate_single(args)
-
- accelerator.wait_for_everyone()
- if is_main_process:
- wandb_logger.log_eval_result(results)
- if wandb_logger.online():
- wandb_logger.write_to_report(results)
- wandb_logger.finish()
- results_list.append(results)
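+        # Guard each run: a failure is logged and recorded as None instead of aborting the remaining configs.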
+ try:
+ if is_main_process and args.wandb_args:
+ wandb_logger = WandbLogger(args)
+ results = cli_evaluate_single(args)
+
+ accelerator.wait_for_everyone()
+ if is_main_process and args.wandb_args and results is not None:
+ wandb_logger.log_eval_result(results)
+ if wandb_logger.online():
+ wandb_logger.write_to_report()
+ wandb_logger.finish()
+ results_list.append(results)
+ except Exception as e:
+ eval_logger.error(f"Error during evaluation: {e}")
+ results_list.append(None)
for args, results in zip(args_list, results_list):
# cli_evaluate will return none if the process is not the main process (rank 0)
@@ -209,7 +213,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
elif args.tasks == "list":
eval_logger.info("Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS))))
sys.exit()
- elif args.tasks == "list_tasks_num":
+ elif args.tasks == "list_with_num":
log_message = (
"\n" + "=" * 70 + "\n" + "\n\tYou are trying to check all the numbers in each task." + "\n\tThis action will download the complete dataset." + "\n\tIf the results are not clear initially, call this again." + "\n\n" + "=" * 70
)
diff --git a/lmms_eval/logging_utils.py b/lmms_eval/logging_utils.py
index 15b7cccc..f09bbaa6 100644
--- a/lmms_eval/logging_utils.py
+++ b/lmms_eval/logging_utils.py
@@ -71,11 +71,6 @@ def init_run(self):
# initialize a W&B run
self.run = wandb.init(**self.wandb_args)
- # call wr inside the init_run method to avoid multiple times logging
- import wandb.apis.reports as wr
-
- self.wr = wr
-
def log_eval_result(self, results):
# Log configs to wandb
configs = self.get_config(results)
@@ -217,80 +212,71 @@ def sanitize_results_dict(self):
return wandb_summary, _results
def prepare_report_by_task(self, results):
- task_names = list(results.get("results", {}).keys())
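+        # Import the W&B reports API lazily; it is only needed when report blocks are built.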
+ import wandb.apis.reports as wr
+
+ task_names = list(self.results.get("results", {}).keys())
blocks = []
for task_name in task_names:
- blocks.append(self.wr.H2(task_name))
+ blocks.append(wr.H2(task_name))
panels = []
for metric_name, metric_value in results.items():
if task_name in metric_name:
panels.append(
- self.wr.ScalarChart(
+ wr.ScalarChart(
title=f"{metric_name}",
metric=f"{metric_name}",
font_size="large",
)
)
_results = {
- "results": {f"{task_name}": results.get("results").get(task_name)},
- "versions": {f"{task_name}": results.get("versions").get(task_name)},
- "n-shot": {f"{task_name}": results.get("n-shot").get(task_name)},
+ "results": {f"{task_name}": self.results.get("results").get(task_name)},
+ "versions": {f"{task_name}": self.results.get("versions").get(task_name)},
+ "n-shot": {f"{task_name}": self.results.get("n-shot").get(task_name)},
}
results_md = utils.make_table(_results)
- blocks.extend([self.wr.MarkdownBlock(results_md), self.wr.PanelGrid(panels=panels)])
- # blocks.extend([
- # self.wr.WeaveBlockSummaryTable(
- # project=self.run.project,
- # entity=self.run.entity,
- # table_name=f"{task_name}_eval_results",
- # ),
- # self.wr.PanelGrid(
- # runsets=[
- # self.wr.Runset(
- # project=self.run.project, entity=self.run.entity,
- # ).set_filters_with_python_expr(f'Name == "{str(self.run.name)}"'),
- # ]
- # ),
- # ])
+ blocks.extend([wr.MarkdownBlock(results_md), wr.PanelGrid(panels=panels)])
+ # TODO: Add results table
return blocks
- def write_to_report(self, results):
- wandb_project = self.run.project
- wandb_entity = self.run.entity
- report = self.wr.Report(
+ def write_to_report(self):
+ import wandb.apis.reports as wr
+
+ report = wr.Report(
project=self.run.project,
entity=self.run.entity,
title=f"({datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) {self.run.id} - Evaluation report",
description=f"Evaluation run by: {self.run.entity} logged to {self.run.url}",
)
- results_md = utils.make_table(results)
task_blocks = self.prepare_report_by_task(self.wandb_results)
blocks = (
[
- self.wr.TableOfContents(),
- self.wr.H1("Complete Evaluation Results"),
- self.wr.WeaveBlockSummaryTable(
+ wr.TableOfContents(),
+ wr.H1("Complete Evaluation Results"),
+ wr.WeaveBlockSummaryTable(
project=self.run.project,
entity=self.run.entity,
- table_name=f"evaluation/eval_results",
+ table_name="evaluation/eval_results",
),
- self.wr.PanelGrid(
+ wr.PanelGrid(
runsets=[
- self.wr.Runset(
+ wr.Runset(
project=self.run.project,
entity=self.run.entity,
).set_filters_with_python_expr(f'Name == "{str(self.run.name)}"'),
]
),
- self.wr.H1("Evaluation Results By Task"),
+ wr.H1("Evaluation Results By Task"),
]
+ task_blocks
+ [
- self.wr.H1("Evaluation Config"),
- self.wr.CodeBlock(json.dumps(self.results["config"], indent=5).split("\n"), language="json"),
+ wr.H1("Evaluation Config"),
+ wr.CodeBlock(
+ json.dumps(self.results["config"], indent=5).split("\n"),
+ language="json",
+ ),
# TODO: Add appendix
]
)
diff --git a/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py b/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py
index f198c267..b012c8f2 100644
--- a/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py
+++ b/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py
@@ -47,9 +47,7 @@ def get_ltor_masks_and_position_ids(
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
- attention_mask = torch.tril(
- torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
- ).view(att_mask_batch, 1, seq_length, seq_length)
+ attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
@@ -66,7 +64,6 @@ def get_ltor_masks_and_position_ids(
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
-
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
@@ -134,9 +131,7 @@ def make_context(
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
- return f"{role}\n{content}", tokenizer.encode(
- role, allowed_special=set(tokenizer.IMAGE_ST)
- ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
+ return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set(tokenizer.IMAGE_ST)) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
@@ -148,22 +143,16 @@ def _tokenize_str(role, content):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
- response_text, response_tokens_part = _tokenize_str(
- "assistant", turn_response
- )
+ response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
- prev_chat = (
- f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
- )
+ prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n"
- current_context_size = (
- len(system_tokens) + len(next_context_tokens) + len(context_tokens)
- )
+ current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
@@ -172,16 +161,7 @@ def _tokenize_str(role, content):
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
- context_tokens += (
- nl_tokens
- + im_start_tokens
- + _tokenize_str("user", query)[1]
- + im_end_tokens
- + nl_tokens
- + im_start_tokens
- + tokenizer.encode("assistant")
- + nl_tokens
- )
+ context_tokens += nl_tokens + im_start_tokens + _tokenize_str("user", query)[1] + im_end_tokens + nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
elif chat_format == "raw":
@@ -202,7 +182,7 @@ def _decode_default(
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
- errors: str='replace',
+ errors: str = "replace",
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
@@ -227,16 +207,7 @@ def _decode_default(
def _decode_chatml(
- tokens: List[int],
- *,
- stop_words: List[str],
- eod_token_ids: List[int],
- tokenizer: PreTrainedTokenizer,
- raw_text_len: int,
- context_length: int,
- verbose: bool = False,
- return_end_reason: bool = False,
- errors: str='replace'
+ tokens: List[int], *, stop_words: List[str], eod_token_ids: List[int], tokenizer: PreTrainedTokenizer, raw_text_len: int, context_length: int, verbose: bool = False, return_end_reason: bool = False, errors: str = "replace"
):
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
@@ -270,7 +241,7 @@ def decode_tokens(
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
- errors: str="replace",
+ errors: str = "replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
@@ -315,42 +286,19 @@ class StopWordsLogitsProcessor(LogitsProcessor):
"""
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
-
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
- raise ValueError(
- f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
- )
+            raise ValueError(f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
- raise ValueError(
- f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
- )
- if any(
- any(
- (not isinstance(token_id, (int, np.integer)) or token_id < 0)
- for token_id in stop_word_ids
- )
- for stop_word_ids in stop_words_ids
- ):
- raise ValueError(
- f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
- )
-
- self.stop_words_ids = list(
- filter(
- lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
- )
- )
+ raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
+ if any(any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids) for stop_word_ids in stop_words_ids):
+ raise ValueError(f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}.")
+
+ self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
- assert (
- len(stop_token_seq) > 0
- ), "Stop words token sequences {} cannot have an empty list".format(
- stop_words_ids
- )
-
- def __call__(
- self, input_ids: torch.LongTensor, scores: torch.FloatTensor
- ) -> torch.FloatTensor:
+ assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(stop_words_ids)
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:
@@ -416,4 +364,4 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
- return (1 - boolean) * val1 + boolean * val2
\ No newline at end of file
+ return (1 - boolean) * val1 + boolean * val2
diff --git a/lmms_eval/models/otterhd.py b/lmms_eval/models/otterhd.py
index e9fa4635..70d6d620 100644
--- a/lmms_eval/models/otterhd.py
+++ b/lmms_eval/models/otterhd.py
@@ -172,7 +172,9 @@ def _collate(x):
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
- generation_output = self.model.generate(**model_inputs, max_new_tokens=gen_kwargs["max_new_tokens"], pad_token_id=self.tokenizer.eos_token_id)
+ generation_output = self.model.generate(
+ **model_inputs, temperature=gen_kwargs["temperature"], max_new_tokens=gen_kwargs["max_new_tokens"], top_p=gen_kwargs["top_p"], num_beams=gen_kwargs["num_beams"], pad_token_id=self.tokenizer.eos_token_id
+ )
generation_texts = self.processor.batch_decode(generation_output, skip_special_tokens=True)
response = [gen_text.split("\x04")[1].strip(" ").strip("\n") for gen_text in generation_texts]
res.extend(response)
diff --git a/lmms_eval/models/qwen_vl.py b/lmms_eval/models/qwen_vl.py
index 01334c51..c3b6c7e1 100644
--- a/lmms_eval/models/qwen_vl.py
+++ b/lmms_eval/models/qwen_vl.py
@@ -133,46 +133,30 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
name = uuid.uuid4().hex.upper()[0:6]
visual.save(f"/tmp/{name}.png")
visual_paths.append(f"/tmp/{name}.png")
- query.append({
- 'image' : f"/tmp/{name}.png"
- })
-
+ query.append({"image": f"/tmp/{name}.png"})
+
# Make a copy for query to save context (text that needs to be masked)
- context_query = [ _ for _ in query]
- context_query.append({'text' : contexts})
- query.append({'text' : contexts + continuation})
-
+ context_query = [_ for _ in query]
+ context_query.append({"text": contexts})
+ query.append({"text": contexts + continuation})
+
context_query = self.tokenizer.from_list_format(context_query)
query = self.tokenizer.from_list_format(query)
-
+
raw_contxt_text, context_tokens = make_context(
- self.tokenizer,
- context_query,
- history=None,
- system='You are a helpful assistant',
- max_window_size=self.model.generation_config.max_window_size,
- chat_format=self.model.generation_config.chat_format
- )
+ self.tokenizer, context_query, history=None, system="You are a helpful assistant", max_window_size=self.model.generation_config.max_window_size, chat_format=self.model.generation_config.chat_format
+ )
context_tokens = torch.tensor([context_tokens])
-
+
raw_continuation_text, continuation_tokens = make_context(
- self.tokenizer,
- query,
- history=None,
- system='You are a helpful assistant',
- max_window_size=self.model.generation_config.max_window_size,
- chat_format=self.model.generation_config.chat_format
- )
+ self.tokenizer, query, history=None, system="You are a helpful assistant", max_window_size=self.model.generation_config.max_window_size, chat_format=self.model.generation_config.chat_format
+ )
continuation_tokens = torch.tensor([continuation_tokens]).to(self.model.device)
attn_mask = torch.ones_like(continuation_tokens).to(self.model.device)
labels = continuation_tokens.clone().to(self.model.device)
- labels[:, :context_tokens.shape[1]] = -100
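+            # Mask the context tokens with -100 so the loss is computed only over the continuation.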
+ labels[:, : context_tokens.shape[1]] = -100
with torch.inference_mode():
- outputs = self.model(
- input_ids = continuation_tokens,
- labels=labels,
- attention_mask = attn_mask
- )
+ outputs = self.model(input_ids=continuation_tokens, labels=labels, attention_mask=attn_mask)
loss = outputs.loss
logits = outputs["logits"]
greedy_tokens = logits.argmax(dim=-1)
@@ -181,10 +165,9 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
max_equal = (greedy_tokens == cont_toks).all()
res.append((float(loss.item()), bool(max_equal)))
pbar.update(1)
-
+
pbar.close()
return res
-
assert False, "We have not implemented this function for Qwen VL yet"
diff --git a/lmms_eval/tasks/cmmmu/utils.py b/lmms_eval/tasks/cmmmu/utils.py
index e5899677..ca4e7b31 100644
--- a/lmms_eval/tasks/cmmmu/utils.py
+++ b/lmms_eval/tasks/cmmmu/utils.py
@@ -6,38 +6,39 @@
from collections import Counter
PROMPT = {
- 'task_instructions': [
- '请回答以下多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。',
- '请回答以下判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。',
- '请回答以下填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。'
+ "task_instructions": [
+ "请回答以下多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。",
+ "请回答以下判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。",
+ "请回答以下填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。",
],
- 'multi_choice_example_format': ['问题:{}\n选项:\n{}\n正确答案:\n'],
- 'T/F_example_format': ['问题:{}\n正确答案:\n'],
- 'short_ans_example_format': ['问题:{}\n正确答案:\n'],
+ "multi_choice_example_format": ["问题:{}\n选项:\n{}\n正确答案:\n"],
+ "T/F_example_format": ["问题:{}\n正确答案:\n"],
+ "short_ans_example_format": ["问题:{}\n正确答案:\n"],
}
+
def construct_prompt(sample):
- question = sample['question']
- task_instructions = PROMPT['task_instructions']
-
- if sample['type'] == '选择':
+ question = sample["question"]
+ task_instructions = PROMPT["task_instructions"]
+
+ if sample["type"] == "选择":
formatted_options = ""
- start_chr = 'A'
+ start_chr = "A"
for i in range(1, 5):
formatted_options += f"({start_chr}) {sample[f'option{i}']}\n"
start_chr = chr(ord(start_chr) + 1)
-
- current_example_template = PROMPT['multi_choice_example_format'][0]
+
+ current_example_template = PROMPT["multi_choice_example_format"][0]
current_example = current_example_template.format(question, formatted_options)
final_input_prompt = task_instructions[0] + "\n\n" + current_example
-
- elif sample['type'] == '判断':
- current_example_template = PROMPT['T/F_example_format'][0]
+
+ elif sample["type"] == "判断":
+ current_example_template = PROMPT["T/F_example_format"][0]
current_example = current_example_template.format(question)
final_input_prompt = task_instructions[1] + "\n\n" + current_example
-
+
else: # For fill in the blanks questions.
- current_example_template = PROMPT['short_ans_example_format'][0]
+ current_example_template = PROMPT["short_ans_example_format"][0]
current_example = current_example_template.format(question)
final_input_prompt = task_instructions[2] + "\n\n" + current_example
@@ -46,9 +47,11 @@ def construct_prompt(sample):
return final_input_prompt
+
def cmmmu_doc_to_text(doc):
return construct_prompt(doc)
+
def cmmmu_doc_to_visual(doc):
prompt = construct_prompt(doc)
image_tokens = re.findall(r"<图片 \d+>", prompt)
@@ -57,6 +60,7 @@ def cmmmu_doc_to_visual(doc):
visual = [doc[image_token].convert("RGB") for image_token in image_tokens]
return visual
+
def cmmmu_process_results(doc, results):
pred = results[0]
if doc["type"] == "选择":
@@ -68,6 +72,7 @@ def cmmmu_process_results(doc, results):
parsed_pred = get_fill_blank_prediction(pred, doc["answer"])
return {"cmmmu_acc": {"id": doc["id"], "subdomain": doc["subcategory"], "question_type": doc["type"], "answer": doc["answer"], "parsed_pred": parsed_pred}}
+
def cmmmu_aggregate_results(results):
evaluation_result = {}
subset_to_eval_samples = defaultdict(list)
@@ -105,16 +110,18 @@ def cmmmu_aggregate_results(results):
print(printable_results)
return printable_results["Overall"]["acc"]
+
def cmmmu_process_test_results_for_submission(doc, results):
response = results[0]
return {"cmmmu_acc": {"id": doc["id"], "type": doc["type"], "response": response}}
+
def cmmmu_test_aggregate_results_for_submission(results):
os.makedirs("./submissions", exist_ok=True)
with open("./submissions/cmmmu_test_for_submission.jsonl", "w") as f:
for result in results:
json.dump(result, f, ensure_ascii=False)
- f.write('\n')
+ f.write("\n")
return -1
@@ -128,26 +135,26 @@ def cmmmu_test_aggregate_results_for_submission(results):
"科学": ["生物", "化学", "地理", "数学", "物理"],
"健康与医学": ["基础医学", "临床医学", "诊断学与实验室医学", "制药", "公共卫生"],
"人文社会科学": ["历史", "文献学", "社会学", "心理学"],
- "技术与工程": ["农业", "建筑学", "计算机科学", "电子学", "能源和电力", "材料", "机械工程"]
+ "技术与工程": ["农业", "建筑学", "计算机科学", "电子学", "能源和电力", "材料", "机械工程"],
}
-def eval_cmmmu(entries):
+def eval_cmmmu(entries):
correct_cnt = 0
for entry in entries:
- parsed_pred = entry.get('parsed_pred', '')
+ parsed_pred = entry.get("parsed_pred", "")
correct = False
- if entry.get('question_type') == '选择':
- if parsed_pred == entry['answer']:
+ if entry.get("question_type") == "选择":
+ if parsed_pred == entry["answer"]:
correct_cnt += 1
correct = True
- elif entry.get('question_type') == '填空':
- norm_answers = normalize_str(entry['answer'], entry['answer'])
+ elif entry.get("question_type") == "填空":
+ norm_answers = normalize_str(entry["answer"], entry["answer"])
for pred in parsed_pred:
# already normalized
- if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
+ if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
for norm_ans in norm_answers:
# only see if the string answer in the string pred
# print(norm_ans, pred)
@@ -156,7 +163,7 @@ def eval_cmmmu(entries):
correct_cnt += 1
correct = True
break
- else: # it's a number
+ else: # it's a number
if pred in norm_answers:
if not correct:
correct_cnt += 1
@@ -164,9 +171,10 @@ def eval_cmmmu(entries):
break
else:
- positive_keywords = ['正确', '对', '准确', '肯定', '对的']
- negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错']
- ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不']
+ positive_keywords = ["正确", "对", "准确", "肯定", "对的"]
+ negative_keywords = ["不对", "错误", "不正确", "不准确", "不合适", "否定", "错的", "错"]
+ ambiguous_keywords = ["对错", "是否正确", "否正确", "或者", "是否", "正确性", "对不"]
+
def judge_similarity(pred_list, positive_keywords, negative_keywords):
positive_count = 0
negative_count = 0
@@ -182,50 +190,43 @@ def judge_similarity(pred_list, positive_keywords, negative_keywords):
elif negative_count > positive_count:
return "错"
else:
- return random.choice(['对', '错'])
- answer = entry['answer']
+ return random.choice(["对", "错"])
+
+ answer = entry["answer"]
parsed_pred = [word for word in parsed_pred if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
result = judge_similarity(parsed_pred, positive_keywords, negative_keywords)
if result == answer:
correct_cnt += 1
correct = True
if correct:
- entry['judge'] = '正确'
+ entry["judge"] = "正确"
else:
- entry['judge'] = '错误'
+ entry["judge"] = "错误"
if len(entries) == 0:
- print('entries_num == 0, please check your file')
- results_count = {
- 'correct_num': 0,
- 'entries_num': 0,
- 'acc': 0
- }
+ print("entries_num == 0, please check your file")
+ results_count = {"correct_num": 0, "entries_num": 0, "acc": 0}
else:
- results_count = {
- 'correct_num': correct_cnt,
- 'entries_num': len(entries),
- 'acc': correct_cnt / len(entries)
- }
+ results_count = {"correct_num": correct_cnt, "entries_num": len(entries), "acc": correct_cnt / len(entries)}
return results_count
def get_multi_choice_prediction(response, all_choices, index2ans):
- for char in [',', '.', '!', '?', ';', ':',"'"]:
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
- response = " " + response + " " # add space to avoid partial match
+ response = " " + response + " " # add space to avoid partial match
candidates = []
for choice in all_choices: # (A) (B) (C) (D)
# Add the choice to candidates each time it appears in the response
- candidates.extend([choice for _ in range(response.count(f'({choice})'))])
+ candidates.extend([choice for _ in range(response.count(f"({choice})"))])
if len(candidates) == 0:
for choice in all_choices: # A B C D
# Similarly, add the choice for each occurrence
- candidates.extend([choice for _ in range(response.count(f'{choice}'))])
+ candidates.extend([choice for _ in range(response.count(f"{choice}"))])
if len(candidates) == 0 and len(response.split()) >= 1:
for index, ans in index2ans.items():
@@ -237,7 +238,7 @@ def get_multi_choice_prediction(response, all_choices, index2ans):
for index, ans in index2ans.items():
if ans in response:
candidates.append(index)
- index_ans = False # it's content ans.
+ index_ans = False # it's content ans.
if len(candidates) == 0: # still not get answer, randomly choose one.
return random.choice(all_choices)
@@ -251,15 +252,16 @@ def get_multi_choice_prediction(response, all_choices, index2ans):
most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
# Combine the most frequent candidates in ABCD order
- return ''.join(most_frequent_candidates)
+ return "".join(most_frequent_candidates)
+
def extract_numbers(string):
# Pattern for numbers with Chinese commas
- pattern_commas = r'-?\d{1,3}(?:,\d{3})+'
+ pattern_commas = r"-?\d{1,3}(?:,\d{3})+"
# Pattern for scientific notation
- pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
+ pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
# Pattern for simple numbers without Chinese commas
- pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)'
+ pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)"
# Extract numbers with Chinese commas
numbers_with_commas = re.findall(pattern_commas, string)
@@ -272,16 +274,19 @@ def extract_numbers(string):
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
return all_numbers
+
def check_is_number(string):
try:
- float(string.replace(',', ''))
+ float(string.replace(",", ""))
return True
except ValueError:
# check if there's comma inside
return False
+
def count_letters(string):
- return sum(c.isalpha() and 'a' <= c <= 'z' or 'A' <= c <= 'Z' for c in string)
+ return sum(c.isalpha() and "a" <= c <= "z" or "A" <= c <= "Z" for c in string)
+
def normalize_str(string, answer):
# check if characters in the string
@@ -294,16 +299,17 @@ def normalize_str(string, answer):
is_number = check_is_number(string)
if is_number:
- string = string.replace(',', '')
+ string = string.replace(",", "")
string = float(string)
# leave 2 decimal
string = round(string, 2)
return [string]
- else: # it's likely to be a string
+ else: # it's likely to be a string
if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
return []
return [string]
+
def get_fill_blank_prediction(response, answer):
"""get the prediction from the generated response,
return a list of predicted strings or numbers"""
@@ -311,15 +317,14 @@ def get_fill_blank_prediction(response, answer):
def get_key_subresponses(response):
key_responses = []
response = response.strip("。").strip()
- sub_responses = re.split(r'。|\n', response)
- indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择',
- '正确答案', '因此', '最后', '答案', '结果']
+ sub_responses = re.split(r"。|\n", response)
+ indicators_of_keys = ["是", "为", "所以", "等于", "方案", "选择", "正确答案", "因此", "最后", "答案", "结果"]
key_responses = []
for index, resp in enumerate(sub_responses):
# if last one, accept it's an equation (the entire response can be just one sentence with equation)
if index == len(sub_responses) - 1:
- indicators_of_keys.extend(['='])
- shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
+ indicators_of_keys.extend(["="])
+ shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
for indicator in indicators_of_keys:
if indicator in resp:
if not shortest_key_response:
@@ -327,18 +332,18 @@ def get_key_subresponses(response):
else:
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
shortest_key_response = resp.split(indicator)[-1].strip()
-
+
if shortest_key_response:
# and it's not trivial
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
key_responses.append(shortest_key_response)
- if len(key_responses) == 0: # did not found any
+        if len(key_responses) == 0:  # did not find any
return [response]
return key_responses
key_responses = get_key_subresponses(response)
- pred_list = key_responses.copy() # keep the original string response
+ pred_list = key_responses.copy() # keep the original string response
for resp in key_responses:
pred_list.extend(extract_numbers(resp))
@@ -352,6 +357,7 @@ def get_key_subresponses(response):
return pred_list
+
def get_TF_prediction(response):
"""get the prediction from the generated response,
return a list of predicted strings or numbers"""
@@ -359,12 +365,11 @@ def get_TF_prediction(response):
def get_key_subresponses(response):
key_responses = []
response = response.strip("。").strip()
- sub_responses = re.split(r'。|\n', response)
- indicators_of_keys = ['是', '为', '所以', '判断',
- '陈述', '说法', '表达', '答案', '结果']
+ sub_responses = re.split(r"。|\n", response)
+ indicators_of_keys = ["是", "为", "所以", "判断", "陈述", "说法", "表达", "答案", "结果"]
key_responses = []
for index, resp in enumerate(sub_responses):
- shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
+ shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
for indicator in indicators_of_keys:
if indicator in resp:
if not shortest_key_response:
@@ -372,25 +377,25 @@ def get_key_subresponses(response):
else:
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
shortest_key_response = resp.split(indicator)[-1].strip()
-
+
if shortest_key_response:
# and it's not trivial
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
key_responses.append(shortest_key_response)
- if len(key_responses) == 0: # did not found any
+        if len(key_responses) == 0:  # did not find any
return [response]
return key_responses
key_responses = get_key_subresponses(response)
- pred_list = key_responses.copy() # keep the original string response
+ pred_list = key_responses.copy() # keep the original string response
# remove duplicates
pred_list = list(set(pred_list))
return pred_list
-def get_multi_choice_info(options):
+def get_multi_choice_info(options):
start_chr = "A"
all_choices = []
index2ans = {}
@@ -400,8 +405,8 @@ def get_multi_choice_info(options):
return index2ans, all_choices
-def calculate_ins_level_acc(results):
+def calculate_ins_level_acc(results):
correct_sum = 0
entries_sum = 0
for cat_results in results.values():
@@ -409,4 +414,4 @@ def calculate_ins_level_acc(results):
entries_sum += cat_results["entries_num"]
if entries_sum == 0:
return 0
- return correct_sum / entries_sum
\ No newline at end of file
+ return correct_sum / entries_sum
diff --git a/lmms_eval/tasks/dc100_en/utils.py b/lmms_eval/tasks/dc100_en/utils.py
index 318f4c21..d3b47ab7 100644
--- a/lmms_eval/tasks/dc100_en/utils.py
+++ b/lmms_eval/tasks/dc100_en/utils.py
@@ -2,6 +2,7 @@
import requests
import re
import logging
+import time
import os
import yaml
from pathlib import Path
@@ -36,7 +37,7 @@ def doc_to_visual(doc):
Provide a few lines for explanation and the rate number at last after "Final Score:"."""
-def get_chat_response(base64_image, prompt, max_retries=3):
+def get_chat_response(base64_image, prompt, max_retries=5, wait_time=10):
headers = {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json",
@@ -68,8 +69,13 @@ def get_chat_response(base64_image, prompt, max_retries=3):
return response_data["choices"][0]["message"]["content"]
except requests.exceptions.RequestException as e:
eval_logger.warning(f"Request failed on attempt {attempt+1}: {e}")
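+            # Pause between attempts; after the final one, log an error and return an empty string instead of raising.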
+ time.sleep(wait_time)
if attempt == max_retries - 1:
- raise
+ eval_logger.error(f"Failed to get response after {max_retries} attempts")
+ return ""
+ except Exception as e:
+ eval_logger.error(f"Error on attempt {attempt+1}: {e}")
+ return ""
def image_to_base64(pil_image):
diff --git a/lmms_eval/tasks/dc200_cn/utils.py b/lmms_eval/tasks/dc200_cn/utils.py
index bd00f4b5..1e729150 100644
--- a/lmms_eval/tasks/dc200_cn/utils.py
+++ b/lmms_eval/tasks/dc200_cn/utils.py
@@ -6,6 +6,7 @@
import yaml
from pathlib import Path
from io import BytesIO
+import time
def doc_to_visual(doc):
@@ -36,7 +37,7 @@ def doc_to_visual(doc):
Provide a few lines for explanation and the rate number at last after "Final Score:"."""
-def get_chat_response(base64_image, prompt, max_retries=3):
+def get_chat_response(base64_image, prompt, max_retries=5, wait_time=10):
headers = {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json",
@@ -68,8 +69,13 @@ def get_chat_response(base64_image, prompt, max_retries=3):
return response_data["choices"][0]["message"]["content"]
except requests.exceptions.RequestException as e:
eval_logger.warning(f"Request failed on attempt {attempt+1}: {e}")
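+            # Back off between attempts; on the last failure return an empty string rather than raising.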
+ time.sleep(wait_time)
if attempt == max_retries - 1:
- raise
+ eval_logger.error(f"Failed to get response after {max_retries} attempts")
+ return ""
+ except Exception as e:
+ eval_logger.error(f"Error on attempt {attempt+1}: {e}")
+ return ""
def image_to_base64(pil_image):
diff --git a/lmms_eval/tasks/iconqa/iconqa_test.yaml b/lmms_eval/tasks/iconqa/iconqa_test.yaml
new file mode 100644
index 00000000..58aa9e70
--- /dev/null
+++ b/lmms_eval/tasks/iconqa/iconqa_test.yaml
@@ -0,0 +1,23 @@
+dataset_path: lmms-lab/ICON-QA
+dataset_kwargs:
+ token: True
+task: "iconqa_test"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.doc_to_visual
+doc_to_text: !function utils.doc_to_text
+doc_to_target: "answers"
+generation_kwargs:
+ max_new_tokens: 32
+ temperature: 0
+ do_sample: False
+metric_list:
+ - metric: anls
+ aggregation: mean
+ higher_is_better: true
+model_specific_prompt_kwargs:
+ default:
+ pre_prompt: ""
+ statement: "Given a set of images and a question, please provide the answer to the question.\n"
+ options_statement: "Question: {question}.\nOptions:\n{options}\nPlease answer with the option letter from the given choices directly."
+ freeform_statement: "Question: {question}.\nPlease answer the question using a single word or phrase."
\ No newline at end of file
diff --git a/lmms_eval/tasks/iconqa/iconqa_val.yaml b/lmms_eval/tasks/iconqa/iconqa_val.yaml
new file mode 100644
index 00000000..7215715e
--- /dev/null
+++ b/lmms_eval/tasks/iconqa/iconqa_val.yaml
@@ -0,0 +1,23 @@
+dataset_path: lmms-lab/ICON-QA
+dataset_kwargs:
+ token: True
+task: "iconqa_val"
+test_split: val
+output_type: generate_until
+doc_to_visual: !function utils.doc_to_visual
+doc_to_text: !function utils.doc_to_text
+doc_to_target: "answers"
+generation_kwargs:
+ max_new_tokens: 32
+ temperature: 0
+ do_sample: False
+metric_list:
+ - metric: anls
+ aggregation: mean
+ higher_is_better: true
+model_specific_prompt_kwargs:
+ default:
+ pre_prompt: ""
+ statement: "Given a set of images and a question, please provide the answer to the question.\n"
+ options_statement: "Question: {question}.\nOptions:\n{options}\nPlease answer with the option letter from the given choices directly."
+ freeform_statement: "Question: {question}.\nPlease answer the question using a single word or phrase."
\ No newline at end of file
diff --git a/lmms_eval/tasks/iconqa/utils.py b/lmms_eval/tasks/iconqa/utils.py
new file mode 100644
index 00000000..5d833fec
--- /dev/null
+++ b/lmms_eval/tasks/iconqa/utils.py
@@ -0,0 +1,62 @@
+import json
+import os
+
+
+def options_to_str(options_prompt):
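+    # Render each option as a lettered line, e.g. "A. <option>".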
+ option_prompt_str = ""
+ for i, option in enumerate(options_prompt):
+ option_choice = chr(ord("A") + i)
+ option_prompt_str += f"{option_choice}. {option}\n"
+
+ option_prompt_str = option_prompt_str.rstrip("\n")
+ return option_prompt_str
+
+
+def doc_to_visual(doc):
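+    # Collect the query image and any choice images, converted to RGB.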
+ image_list = []
+ if "query_image" in doc:
+ image_list.append(doc["query_image"].convert("RGB"))
+ if "choice_image_0" in doc:
+ image_list.append(doc["choice_image_0"].convert("RGB"))
+ if "choice_image_1" in doc:
+ image_list.append(doc["choice_image_1"].convert("RGB"))
+ if "choice_image_2" in doc:
+ image_list.append(doc["choice_image_2"].convert("RGB"))
+ if "choice_image_3" in doc:
+ image_list.append(doc["choice_image_3"].convert("RGB"))
+ if "choice_image_4" in doc:
+ image_list.append(doc["choice_image_4"].convert("RGB"))
+ assert len(image_list) < 6, "Maximum 5 images allowed for ICON-QA"
+ return image_list
+
+
+def doc_to_text(doc, model_specific_prompt_kwargs):
+ question = doc["question"]
+ ques_type = doc["ques_type"]
+ options_prompt = []
+
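+    # Build the prompt according to question type: image choice, text choice, or fill-in-the-blank.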
+ if ques_type == "choose_img":
+ options_prompt.append("The first image.")
+ options_prompt.append("The second image.")
+
+ options_str = options_to_str(options_prompt)
+ full_prompt = f"{model_specific_prompt_kwargs['pre_prompt']}{model_specific_prompt_kwargs['statement']}{model_specific_prompt_kwargs['options_statement'].format(question=question, options=options_str)}"
+
+ elif ques_type == "choose_txt":
+ choices = doc["choices"].split(",")
+        for choice in choices:
+            options_prompt.append(choice)
+
+ options_str = options_to_str(options_prompt)
+ full_prompt = f"{model_specific_prompt_kwargs['pre_prompt']}{model_specific_prompt_kwargs['statement']}{model_specific_prompt_kwargs['options_statement'].format(question=question, options=options_str)}"
+
+ elif ques_type == "fill_in_blank":
+ full_prompt = f"{model_specific_prompt_kwargs['pre_prompt']}{model_specific_prompt_kwargs['statement']}{model_specific_prompt_kwargs['freeform_statement'].format(question=question)}"
+
+ return full_prompt
+
+
+def test_process_results(doc, results):
+ pred = results[0]
+ questionId = doc["questionId"]
+ return {"anls": {"questionId": int(questionId), "answer": pred}}
diff --git a/lmms_eval/tasks/multidocvqa/utils.py b/lmms_eval/tasks/multidocvqa/utils.py
index 33241cf3..698e25cc 100644
--- a/lmms_eval/tasks/multidocvqa/utils.py
+++ b/lmms_eval/tasks/multidocvqa/utils.py
@@ -12,36 +12,41 @@ def multidocvqa_doc_to_text(doc, model_specific_prompt_kwargs):
return f"{pre_prompt}{question}{post_prompt}"
+
def multidocvqa_doc_to_visual(doc):
- return [doc[f'image_{i}'].convert("RGB") for i in range(1, 21) if doc[f'image_{i}'] is not None]
+ return [doc[f"image_{i}"].convert("RGB") for i in range(1, 21) if doc[f"image_{i}"] is not None]
+
def multidocvqa_process_results(doc, results):
pred_answer = results[0]
- answer = ast.literal_eval(doc['answers'])
+ answer = ast.literal_eval(doc["answers"])
+
+ return {"anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}, "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}}
- return {"anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer},
- "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}}
def multidocvqa_aggregate_results_anls(results):
keys = {k for result in results for k in result}
results = {key: [result.get(key, None) for result in results] for key in keys}
evaluator = Evaluator(case_sensitive=False)
- metric = evaluator.get_metrics(results['answer'], results['pred_answer'])
+ metric = evaluator.get_metrics(results["answer"], results["pred_answer"])
+
+ return sum(metric["anls"]) / len(metric["anls"])
- return sum(metric['anls']) / len(metric['anls'])
def multidocvqa_aggregate_results_accuracy(results):
keys = {k for result in results for k in result}
results = {key: [result.get(key, None) for result in results] for key in keys}
evaluator = Evaluator(case_sensitive=False)
- metric = evaluator.get_metrics(results['answer'], results['pred_answer'])
+ metric = evaluator.get_metrics(results["answer"], results["pred_answer"])
+
+ return sum(metric["accuracy"]) / len(metric["accuracy"])
- return sum(metric['accuracy']) / len(metric['accuracy'])
def multidocvqa_process_test_results_for_submission(doc, results):
answer = results[0]
return {"submission": {"questionId": int(doc["questionId"]), "answer": answer, "answer_page": None}}
+
def multidocvqa_test_aggregate_results_for_submission(results):
os.makedirs("./submissions", exist_ok=True)
with open("./submissions/multidocvqa_test_for_submission.json", "w") as f:
@@ -56,7 +61,6 @@ def multidocvqa_test_aggregate_results_for_submission(results):
class Evaluator:
def __init__(self, case_sensitive=False):
-
self.case_sensitive = case_sensitive
self.get_edit_distance = levenshtein_distance
self.anls_threshold = 0.5
@@ -71,7 +75,7 @@ def get_metrics(self, gt_answers, preds):
batch_accuracy.append(self._calculate_accuracy(gt, pred))
batch_anls.append(self._calculate_anls(gt, pred))
- return {'accuracy': batch_accuracy, 'anls': batch_anls}
+ return {"accuracy": batch_accuracy, "anls": batch_anls}
def _preprocess_str(self, string):
if not self.case_sensitive:
@@ -80,8 +84,7 @@ def _preprocess_str(self, string):
return string.strip()
def _calculate_accuracy(self, gt, pred):
-
- if pred == 'none':
+ if pred == "none":
return 0
for gt_elm in gt:
@@ -94,7 +97,7 @@ def _calculate_anls(self, gt, pred):
if len(pred) == 0:
return 0
- if pred == 'none':
+ if pred == "none":
return 0
answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt]
@@ -102,7 +105,8 @@ def _calculate_anls(self, gt, pred):
anls = max_similarity if max_similarity >= self.anls_threshold else 0
return anls
-
-if __name__ == '__main__':
- print('-----------------')
- multidocvqa_aggregate_results_anls([{"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"}, {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"}])
\ No newline at end of file
+
+
+if __name__ == "__main__":
+ print("-----------------")
+ multidocvqa_aggregate_results_anls([{"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"}, {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"}])
diff --git a/lmms_eval/tasks/seedbench_2/utils.py b/lmms_eval/tasks/seedbench_2/utils.py
index 5ec7fccc..abcc97b8 100644
--- a/lmms_eval/tasks/seedbench_2/utils.py
+++ b/lmms_eval/tasks/seedbench_2/utils.py
@@ -4,21 +4,23 @@
def seed_doc_to_visual(doc):
return [image.convert("RGB") for image in doc["image"]]
-def parse_choice_img(choice : str, img_token : str):
+
+def parse_choice_img(choice: str, img_token: str):
if "jpg" in choice or "png" in choice:
return img_token
return choice
-def seed_doc_to_text(doc, model_specific_kwargs = None):
+
+def seed_doc_to_text(doc, model_specific_kwargs=None):
question = doc["question"]
- question.replace("", model_specific_kwargs['img_token'])
+ question.replace("", model_specific_kwargs["img_token"])
question += "\n" + f"A. {parse_choice_img(doc['choice_a'], model_specific_kwargs['img_token'])}\n"
question += f"B. {parse_choice_img(doc['choice_b'], model_specific_kwargs['img_token'])}\n"
question += f"C. {parse_choice_img(doc['choice_c'], model_specific_kwargs['img_token'])}\n"
question += f"D. {parse_choice_img(doc['choice_d'], model_specific_kwargs['img_token'])}"
- if (doc['data_type'] == "Image Generation"):
- num_img_in_question = len(doc['image']) - 4
- prepend_tokens = [model_specific_kwargs['img_token']] * num_img_in_question
+ if doc["data_type"] == "Image Generation":
+ num_img_in_question = len(doc["image"]) - 4
+ prepend_tokens = [model_specific_kwargs["img_token"]] * num_img_in_question
question = " ".join(prepend_tokens) + "\n" + question
return f"{question}\n{model_specific_kwargs['post_prompt']}"
diff --git a/lmms_eval/tasks/stvqa/utils.py b/lmms_eval/tasks/stvqa/utils.py
index 1f106f06..bb317f69 100644
--- a/lmms_eval/tasks/stvqa/utils.py
+++ b/lmms_eval/tasks/stvqa/utils.py
@@ -1,21 +1,25 @@
import os
import json
+
def stvqa_doc_to_text(doc, model_specific_prompt_kwargs):
question = doc["question"]
pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
post_prompt = model_specific_prompt_kwargs["post_prompt"]
return f"{pre_prompt}{question}{post_prompt}"
+
def stvqa_doc_to_visual(doc):
return [doc["image"].convert("RGB")]
+
def stvqa_process_results(doc, results):
answer = results[0]
return {"submission": {"question_id": int(doc["question_id"]), "answer": answer}}
+
def stvqa_aggregate_submissions(results):
os.makedirs("./submissions", exist_ok=True)
with open("./submissions/stvqa_test_for_submission.json", "w") as f:
json.dump(results, f)
- return -1
\ No newline at end of file
+ return -1