From fdcc63f7233d5b46ec712fd9cd278d734b97757d Mon Sep 17 00:00:00 2001 From: Li Bo Date: Tue, 13 Feb 2024 19:02:18 +0800 Subject: [PATCH] [Model] Add models (#47) * Refactor logging and model initialization * Fix wandb_logger.online() method call * Add error handling during evaluation * Add wait time and error handling in get_chat_response function * Update wait_time in get_chat_response function * Refactor code for improved readability and maintainability * Refactor doc_to_visual function to handle multiple images in ICON-QA tasks * Refactor logging_utils.py and utils.py This commit refactors the `logging_utils.py` and `utils.py` files. It removes unused imports, adjusts code formatting, and updates the `get_chat_response` function to increase the `wait_time` parameter from 5 to 10. * Refactor code for wandb logging and generation in OtterHD class * Refactor prepare_report_by_task method in logging_utils.py * Update generation parameters in OtterHD model * Update generation parameters in OtterHD model * Squashed commit of the following: commit 21dea7bdc123414fa79fcc9929d9c924a0ca8c17 Author: kcz358 <92624596+kcz358@users.noreply.github.com> Date: Tue Feb 13 18:50:37 2024 +0800 Fix seedbench choices bugs (#45) commit 12144a692b70f58f890e9205f044fde8803705ff Author: XinrunDu <154438029+XinrunDu@users.noreply.github.com> Date: Tue Feb 13 18:50:23 2024 +0800 add stvqa and multidocvqa (#46) commit aca1e6d05c7e410a23a5d4be2f1385e61c332b8e Author: XinrunDu <154438029+XinrunDu@users.noreply.github.com> Date: Sun Feb 11 00:54:39 2024 +0800 add cmmmu (#44) Co-authored-by: ygjin11 <1633504509@qq.com> commit 0925443560d8940a09c2114910b53016b8e71ed8 Author: kcz358 <92624596+kcz358@users.noreply.github.com> Date: Sun Feb 11 00:54:23 2024 +0800 [Feat] Add qwen loglikelihood (#43) * Add qwen loglikelihood * Revise the pyproject dependency. Move tiktoken out from optional-dependencies * Add ferret-bench * Add seedbench 2, test on llava commit 16f1cf24eeb731fb45bf19c0ec9bacb8c0330b6b Author: JvThunder <44111143+JvThunder@users.noreply.github.com> Date: Wed Feb 7 00:08:22 2024 +0800 Joshua/vizwizvqa refactor (#42) * refactor vizwizvqa task * Merge commit '9bbbad51a77051fcf676438f81e81f723c1b438b' * Fix exact_match accuracy calculation in vizwiz_vqa_process_results * Update vizwiz_vqa tasks --------- Co-authored-by: Fanyi Pu --- lmms_eval/__main__.py | 28 +-- lmms_eval/logging_utils.py | 66 +++---- .../model_utils/qwen/qwen_generate_utils.py | 90 +++------- lmms_eval/models/otterhd.py | 4 +- lmms_eval/models/qwen_vl.py | 47 ++--- lmms_eval/tasks/cmmmu/utils.py | 161 +++++++++--------- lmms_eval/tasks/dc100_en/utils.py | 10 +- lmms_eval/tasks/dc200_cn/utils.py | 10 +- lmms_eval/tasks/iconqa/iconqa_test.yaml | 23 +++ lmms_eval/tasks/iconqa/iconqa_val.yaml | 23 +++ lmms_eval/tasks/iconqa/utils.py | 62 +++++++ lmms_eval/tasks/multidocvqa/utils.py | 38 +++-- lmms_eval/tasks/seedbench_2/utils.py | 14 +- lmms_eval/tasks/stvqa/utils.py | 6 +- 14 files changed, 320 insertions(+), 262 deletions(-) create mode 100644 lmms_eval/tasks/iconqa/iconqa_test.yaml create mode 100644 lmms_eval/tasks/iconqa/iconqa_val.yaml create mode 100644 lmms_eval/tasks/iconqa/utils.py diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py index a6285bcd..8fd1b4ae 100644 --- a/lmms_eval/__main__.py +++ b/lmms_eval/__main__.py @@ -172,17 +172,21 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: is_main_process = False for args in args_list: - if is_main_process: - wandb_logger = WandbLogger(args) - results = cli_evaluate_single(args) - - accelerator.wait_for_everyone() - if is_main_process: - wandb_logger.log_eval_result(results) - if wandb_logger.online(): - wandb_logger.write_to_report(results) - wandb_logger.finish() - results_list.append(results) + try: + if is_main_process and args.wandb_args: + wandb_logger = WandbLogger(args) + results = cli_evaluate_single(args) + + accelerator.wait_for_everyone() + if is_main_process and args.wandb_args and results is not None: + wandb_logger.log_eval_result(results) + if wandb_logger.online(): + wandb_logger.write_to_report() + wandb_logger.finish() + results_list.append(results) + except Exception as e: + eval_logger.error(f"Error during evaluation: {e}") + results_list.append(None) for args, results in zip(args_list, results_list): # cli_evaluate will return none if the process is not the main process (rank 0) @@ -209,7 +213,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None: elif args.tasks == "list": eval_logger.info("Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS)))) sys.exit() - elif args.tasks == "list_tasks_num": + elif args.tasks == "list_with_num": log_message = ( "\n" + "=" * 70 + "\n" + "\n\tYou are trying to check all the numbers in each task." + "\n\tThis action will download the complete dataset." + "\n\tIf the results are not clear initially, call this again." + "\n\n" + "=" * 70 ) diff --git a/lmms_eval/logging_utils.py b/lmms_eval/logging_utils.py index 15b7cccc..f09bbaa6 100644 --- a/lmms_eval/logging_utils.py +++ b/lmms_eval/logging_utils.py @@ -71,11 +71,6 @@ def init_run(self): # initialize a W&B run self.run = wandb.init(**self.wandb_args) - # call wr inside the init_run method to avoid multiple times logging - import wandb.apis.reports as wr - - self.wr = wr - def log_eval_result(self, results): # Log configs to wandb configs = self.get_config(results) @@ -217,80 +212,71 @@ def sanitize_results_dict(self): return wandb_summary, _results def prepare_report_by_task(self, results): - task_names = list(results.get("results", {}).keys()) + import wandb.apis.reports as wr + + task_names = list(self.results.get("results", {}).keys()) blocks = [] for task_name in task_names: - blocks.append(self.wr.H2(task_name)) + blocks.append(wr.H2(task_name)) panels = [] for metric_name, metric_value in results.items(): if task_name in metric_name: panels.append( - self.wr.ScalarChart( + wr.ScalarChart( title=f"{metric_name}", metric=f"{metric_name}", font_size="large", ) ) _results = { - "results": {f"{task_name}": results.get("results").get(task_name)}, - "versions": {f"{task_name}": results.get("versions").get(task_name)}, - "n-shot": {f"{task_name}": results.get("n-shot").get(task_name)}, + "results": {f"{task_name}": self.results.get("results").get(task_name)}, + "versions": {f"{task_name}": self.results.get("versions").get(task_name)}, + "n-shot": {f"{task_name}": self.results.get("n-shot").get(task_name)}, } results_md = utils.make_table(_results) - blocks.extend([self.wr.MarkdownBlock(results_md), self.wr.PanelGrid(panels=panels)]) - # blocks.extend([ - # self.wr.WeaveBlockSummaryTable( - # project=self.run.project, - # entity=self.run.entity, - # table_name=f"{task_name}_eval_results", - # ), - # self.wr.PanelGrid( - # runsets=[ - # self.wr.Runset( - # project=self.run.project, entity=self.run.entity, - # ).set_filters_with_python_expr(f'Name == "{str(self.run.name)}"'), - # ] - # ), - # ]) + blocks.extend([wr.MarkdownBlock(results_md), wr.PanelGrid(panels=panels)]) + # TODO: Add results table return blocks - def write_to_report(self, results): - wandb_project = self.run.project - wandb_entity = self.run.entity - report = self.wr.Report( + def write_to_report(self): + import wandb.apis.reports as wr + + report = wr.Report( project=self.run.project, entity=self.run.entity, title=f"({datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) {self.run.id} - Evaluation report", description=f"Evaluation run by: {self.run.entity} logged to {self.run.url}", ) - results_md = utils.make_table(results) task_blocks = self.prepare_report_by_task(self.wandb_results) blocks = ( [ - self.wr.TableOfContents(), - self.wr.H1("Complete Evaluation Results"), - self.wr.WeaveBlockSummaryTable( + wr.TableOfContents(), + wr.H1("Complete Evaluation Results"), + wr.WeaveBlockSummaryTable( project=self.run.project, entity=self.run.entity, - table_name=f"evaluation/eval_results", + table_name="evaluation/eval_results", ), - self.wr.PanelGrid( + wr.PanelGrid( runsets=[ - self.wr.Runset( + wr.Runset( project=self.run.project, entity=self.run.entity, ).set_filters_with_python_expr(f'Name == "{str(self.run.name)}"'), ] ), - self.wr.H1("Evaluation Results By Task"), + wr.H1("Evaluation Results By Task"), ] + task_blocks + [ - self.wr.H1("Evaluation Config"), - self.wr.CodeBlock(json.dumps(self.results["config"], indent=5).split("\n"), language="json"), + wr.H1("Evaluation Config"), + wr.CodeBlock( + json.dumps(self.results["config"], indent=5).split("\n"), + language="json", + ), # TODO: Add appendix ] ) diff --git a/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py b/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py index f198c267..b012c8f2 100644 --- a/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py +++ b/lmms_eval/models/model_utils/qwen/qwen_generate_utils.py @@ -47,9 +47,7 @@ def get_ltor_masks_and_position_ids( att_mask_batch = micro_batch_size else: att_mask_batch = 1 - attention_mask = torch.tril( - torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) - ).view(att_mask_batch, 1, seq_length, seq_length) + attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(att_mask_batch, 1, seq_length, seq_length) # Loss mask. loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) @@ -66,7 +64,6 @@ def get_ltor_masks_and_position_ids( if reset_position_ids or reset_attention_mask: # Loop through the batches: for b in range(micro_batch_size): - # Find indecies where EOD token is. eod_index = position_ids[b, data[b] == eod_token] # Detach indecies from positions if going to modify positions. @@ -134,9 +131,7 @@ def make_context( nl_tokens = tokenizer.encode("\n") def _tokenize_str(role, content): - return f"{role}\n{content}", tokenizer.encode( - role, allowed_special=set(tokenizer.IMAGE_ST) - ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST)) + return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set(tokenizer.IMAGE_ST)) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST)) system_text, system_tokens_part = _tokenize_str("system", system) system_tokens = im_start_tokens + system_tokens_part + im_end_tokens @@ -148,22 +143,16 @@ def _tokenize_str(role, content): query_text, query_tokens_part = _tokenize_str("user", turn_query) query_tokens = im_start_tokens + query_tokens_part + im_end_tokens if turn_response is not None: - response_text, response_tokens_part = _tokenize_str( - "assistant", turn_response - ) + response_text, response_tokens_part = _tokenize_str("assistant", turn_response) response_tokens = im_start_tokens + response_tokens_part + im_end_tokens next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens - prev_chat = ( - f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" - ) + prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" else: next_context_tokens = nl_tokens + query_tokens + nl_tokens prev_chat = f"\n{im_start}{query_text}{im_end}\n" - current_context_size = ( - len(system_tokens) + len(next_context_tokens) + len(context_tokens) - ) + current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens) if current_context_size < max_window_size: context_tokens = next_context_tokens + context_tokens raw_text = prev_chat + raw_text @@ -172,16 +161,7 @@ def _tokenize_str(role, content): context_tokens = system_tokens + context_tokens raw_text = f"{im_start}{system_text}{im_end}" + raw_text - context_tokens += ( - nl_tokens - + im_start_tokens - + _tokenize_str("user", query)[1] - + im_end_tokens - + nl_tokens - + im_start_tokens - + tokenizer.encode("assistant") - + nl_tokens - ) + context_tokens += nl_tokens + im_start_tokens + _tokenize_str("user", query)[1] + im_end_tokens + nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" elif chat_format == "raw": @@ -202,7 +182,7 @@ def _decode_default( raw_text_len: int, verbose: bool = False, return_end_reason: bool = False, - errors: str='replace', + errors: str = "replace", ): trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] if verbose: @@ -227,16 +207,7 @@ def _decode_default( def _decode_chatml( - tokens: List[int], - *, - stop_words: List[str], - eod_token_ids: List[int], - tokenizer: PreTrainedTokenizer, - raw_text_len: int, - context_length: int, - verbose: bool = False, - return_end_reason: bool = False, - errors: str='replace' + tokens: List[int], *, stop_words: List[str], eod_token_ids: List[int], tokenizer: PreTrainedTokenizer, raw_text_len: int, context_length: int, verbose: bool = False, return_end_reason: bool = False, errors: str = "replace" ): end_reason = f"Gen length {len(tokens)}" eod_token_idx = context_length @@ -270,7 +241,7 @@ def decode_tokens( chat_format: str, verbose: bool = False, return_end_reason: bool = False, - errors: str="replace", + errors: str = "replace", ) -> str: if torch.is_tensor(tokens): tokens = tokens.cpu().numpy().tolist() @@ -315,42 +286,19 @@ class StopWordsLogitsProcessor(LogitsProcessor): """ def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): - if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: - raise ValueError( - f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." - ) + raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.") if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): - raise ValueError( - f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." - ) - if any( - any( - (not isinstance(token_id, (int, np.integer)) or token_id < 0) - for token_id in stop_word_ids - ) - for stop_word_ids in stop_words_ids - ): - raise ValueError( - f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." - ) - - self.stop_words_ids = list( - filter( - lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids - ) - ) + raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.") + if any(any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids) for stop_word_ids in stop_words_ids): + raise ValueError(f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}.") + + self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids)) self.eos_token_id = eos_token_id for stop_token_seq in self.stop_words_ids: - assert ( - len(stop_token_seq) > 0 - ), "Stop words token sequences {} cannot have an empty list".format( - stop_words_ids - ) - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: + assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(stop_words_ids) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: stopped_samples = self._calc_stopped_samples(input_ids) for i, should_stop in enumerate(stopped_samples): if should_stop: @@ -416,4 +364,4 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): def switch(val1, val2, boolean): boolean = boolean.type_as(val1) - return (1 - boolean) * val1 + boolean * val2 \ No newline at end of file + return (1 - boolean) * val1 + boolean * val2 diff --git a/lmms_eval/models/otterhd.py b/lmms_eval/models/otterhd.py index e9fa4635..70d6d620 100644 --- a/lmms_eval/models/otterhd.py +++ b/lmms_eval/models/otterhd.py @@ -172,7 +172,9 @@ def _collate(x): gen_kwargs["top_p"] = None if "num_beams" not in gen_kwargs: gen_kwargs["num_beams"] = 1 - generation_output = self.model.generate(**model_inputs, max_new_tokens=gen_kwargs["max_new_tokens"], pad_token_id=self.tokenizer.eos_token_id) + generation_output = self.model.generate( + **model_inputs, temperature=gen_kwargs["temperature"], max_new_tokens=gen_kwargs["max_new_tokens"], top_p=gen_kwargs["top_p"], num_beams=gen_kwargs["num_beams"], pad_token_id=self.tokenizer.eos_token_id + ) generation_texts = self.processor.batch_decode(generation_output, skip_special_tokens=True) response = [gen_text.split("\x04")[1].strip(" ").strip("\n") for gen_text in generation_texts] res.extend(response) diff --git a/lmms_eval/models/qwen_vl.py b/lmms_eval/models/qwen_vl.py index 01334c51..c3b6c7e1 100644 --- a/lmms_eval/models/qwen_vl.py +++ b/lmms_eval/models/qwen_vl.py @@ -133,46 +133,30 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: name = uuid.uuid4().hex.upper()[0:6] visual.save(f"/tmp/{name}.png") visual_paths.append(f"/tmp/{name}.png") - query.append({ - 'image' : f"/tmp/{name}.png" - }) - + query.append({"image": f"/tmp/{name}.png"}) + # Make a copy for query to save context (text that needs to be masked) - context_query = [ _ for _ in query] - context_query.append({'text' : contexts}) - query.append({'text' : contexts + continuation}) - + context_query = [_ for _ in query] + context_query.append({"text": contexts}) + query.append({"text": contexts + continuation}) + context_query = self.tokenizer.from_list_format(context_query) query = self.tokenizer.from_list_format(query) - + raw_contxt_text, context_tokens = make_context( - self.tokenizer, - context_query, - history=None, - system='You are a helpful assistant', - max_window_size=self.model.generation_config.max_window_size, - chat_format=self.model.generation_config.chat_format - ) + self.tokenizer, context_query, history=None, system="You are a helpful assistant", max_window_size=self.model.generation_config.max_window_size, chat_format=self.model.generation_config.chat_format + ) context_tokens = torch.tensor([context_tokens]) - + raw_continuation_text, continuation_tokens = make_context( - self.tokenizer, - query, - history=None, - system='You are a helpful assistant', - max_window_size=self.model.generation_config.max_window_size, - chat_format=self.model.generation_config.chat_format - ) + self.tokenizer, query, history=None, system="You are a helpful assistant", max_window_size=self.model.generation_config.max_window_size, chat_format=self.model.generation_config.chat_format + ) continuation_tokens = torch.tensor([continuation_tokens]).to(self.model.device) attn_mask = torch.ones_like(continuation_tokens).to(self.model.device) labels = continuation_tokens.clone().to(self.model.device) - labels[:, :context_tokens.shape[1]] = -100 + labels[:, : context_tokens.shape[1]] = -100 with torch.inference_mode(): - outputs = self.model( - input_ids = continuation_tokens, - labels=labels, - attention_mask = attn_mask - ) + outputs = self.model(input_ids=continuation_tokens, labels=labels, attention_mask=attn_mask) loss = outputs.loss logits = outputs["logits"] greedy_tokens = logits.argmax(dim=-1) @@ -181,10 +165,9 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: max_equal = (greedy_tokens == cont_toks).all() res.append((float(loss.item()), bool(max_equal))) pbar.update(1) - + pbar.close() return res - assert False, "We have not implemented this function for Qwen VL yet" diff --git a/lmms_eval/tasks/cmmmu/utils.py b/lmms_eval/tasks/cmmmu/utils.py index e5899677..ca4e7b31 100644 --- a/lmms_eval/tasks/cmmmu/utils.py +++ b/lmms_eval/tasks/cmmmu/utils.py @@ -6,38 +6,39 @@ from collections import Counter PROMPT = { - 'task_instructions': [ - '请回答以下多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。', - '请回答以下判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。', - '请回答以下填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。' + "task_instructions": [ + "请回答以下多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。", + "请回答以下判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。", + "请回答以下填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。", ], - 'multi_choice_example_format': ['问题:{}\n选项:\n{}\n正确答案:\n'], - 'T/F_example_format': ['问题:{}\n正确答案:\n'], - 'short_ans_example_format': ['问题:{}\n正确答案:\n'], + "multi_choice_example_format": ["问题:{}\n选项:\n{}\n正确答案:\n"], + "T/F_example_format": ["问题:{}\n正确答案:\n"], + "short_ans_example_format": ["问题:{}\n正确答案:\n"], } + def construct_prompt(sample): - question = sample['question'] - task_instructions = PROMPT['task_instructions'] - - if sample['type'] == '选择': + question = sample["question"] + task_instructions = PROMPT["task_instructions"] + + if sample["type"] == "选择": formatted_options = "" - start_chr = 'A' + start_chr = "A" for i in range(1, 5): formatted_options += f"({start_chr}) {sample[f'option{i}']}\n" start_chr = chr(ord(start_chr) + 1) - - current_example_template = PROMPT['multi_choice_example_format'][0] + + current_example_template = PROMPT["multi_choice_example_format"][0] current_example = current_example_template.format(question, formatted_options) final_input_prompt = task_instructions[0] + "\n\n" + current_example - - elif sample['type'] == '判断': - current_example_template = PROMPT['T/F_example_format'][0] + + elif sample["type"] == "判断": + current_example_template = PROMPT["T/F_example_format"][0] current_example = current_example_template.format(question) final_input_prompt = task_instructions[1] + "\n\n" + current_example - + else: # For fill in the blanks questions. - current_example_template = PROMPT['short_ans_example_format'][0] + current_example_template = PROMPT["short_ans_example_format"][0] current_example = current_example_template.format(question) final_input_prompt = task_instructions[2] + "\n\n" + current_example @@ -46,9 +47,11 @@ def construct_prompt(sample): return final_input_prompt + def cmmmu_doc_to_text(doc): return construct_prompt(doc) + def cmmmu_doc_to_visual(doc): prompt = construct_prompt(doc) image_tokens = re.findall(r"<图片 \d+>", prompt) @@ -57,6 +60,7 @@ def cmmmu_doc_to_visual(doc): visual = [doc[image_token].convert("RGB") for image_token in image_tokens] return visual + def cmmmu_process_results(doc, results): pred = results[0] if doc["type"] == "选择": @@ -68,6 +72,7 @@ def cmmmu_process_results(doc, results): parsed_pred = get_fill_blank_prediction(pred, doc["answer"]) return {"cmmmu_acc": {"id": doc["id"], "subdomain": doc["subcategory"], "question_type": doc["type"], "answer": doc["answer"], "parsed_pred": parsed_pred}} + def cmmmu_aggregate_results(results): evaluation_result = {} subset_to_eval_samples = defaultdict(list) @@ -105,16 +110,18 @@ def cmmmu_aggregate_results(results): print(printable_results) return printable_results["Overall"]["acc"] + def cmmmu_process_test_results_for_submission(doc, results): response = results[0] return {"cmmmu_acc": {"id": doc["id"], "type": doc["type"], "response": response}} + def cmmmu_test_aggregate_results_for_submission(results): os.makedirs("./submissions", exist_ok=True) with open("./submissions/cmmmu_test_for_submission.jsonl", "w") as f: for result in results: json.dump(result, f, ensure_ascii=False) - f.write('\n') + f.write("\n") return -1 @@ -128,26 +135,26 @@ def cmmmu_test_aggregate_results_for_submission(results): "科学": ["生物", "化学", "地理", "数学", "物理"], "健康与医学": ["基础医学", "临床医学", "诊断学与实验室医学", "制药", "公共卫生"], "人文社会科学": ["历史", "文献学", "社会学", "心理学"], - "技术与工程": ["农业", "建筑学", "计算机科学", "电子学", "能源和电力", "材料", "机械工程"] + "技术与工程": ["农业", "建筑学", "计算机科学", "电子学", "能源和电力", "材料", "机械工程"], } -def eval_cmmmu(entries): +def eval_cmmmu(entries): correct_cnt = 0 for entry in entries: - parsed_pred = entry.get('parsed_pred', '') + parsed_pred = entry.get("parsed_pred", "") correct = False - if entry.get('question_type') == '选择': - if parsed_pred == entry['answer']: + if entry.get("question_type") == "选择": + if parsed_pred == entry["answer"]: correct_cnt += 1 correct = True - elif entry.get('question_type') == '填空': - norm_answers = normalize_str(entry['answer'], entry['answer']) + elif entry.get("question_type") == "填空": + norm_answers = normalize_str(entry["answer"], entry["answer"]) for pred in parsed_pred: # already normalized - if isinstance(pred, str): # if it's a string, then find if ans in the pred_i + if isinstance(pred, str): # if it's a string, then find if ans in the pred_i for norm_ans in norm_answers: # only see if the string answer in the string pred # print(norm_ans, pred) @@ -156,7 +163,7 @@ def eval_cmmmu(entries): correct_cnt += 1 correct = True break - else: # it's a number + else: # it's a number if pred in norm_answers: if not correct: correct_cnt += 1 @@ -164,9 +171,10 @@ def eval_cmmmu(entries): break else: - positive_keywords = ['正确', '对', '准确', '肯定', '对的'] - negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错'] - ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不'] + positive_keywords = ["正确", "对", "准确", "肯定", "对的"] + negative_keywords = ["不对", "错误", "不正确", "不准确", "不合适", "否定", "错的", "错"] + ambiguous_keywords = ["对错", "是否正确", "否正确", "或者", "是否", "正确性", "对不"] + def judge_similarity(pred_list, positive_keywords, negative_keywords): positive_count = 0 negative_count = 0 @@ -182,50 +190,43 @@ def judge_similarity(pred_list, positive_keywords, negative_keywords): elif negative_count > positive_count: return "错" else: - return random.choice(['对', '错']) - answer = entry['answer'] + return random.choice(["对", "错"]) + + answer = entry["answer"] parsed_pred = [word for word in parsed_pred if not any(ambiguous in word for ambiguous in ambiguous_keywords)] result = judge_similarity(parsed_pred, positive_keywords, negative_keywords) if result == answer: correct_cnt += 1 correct = True if correct: - entry['judge'] = '正确' + entry["judge"] = "正确" else: - entry['judge'] = '错误' + entry["judge"] = "错误" if len(entries) == 0: - print('entries_num == 0, please check your file') - results_count = { - 'correct_num': 0, - 'entries_num': 0, - 'acc': 0 - } + print("entries_num == 0, please check your file") + results_count = {"correct_num": 0, "entries_num": 0, "acc": 0} else: - results_count = { - 'correct_num': correct_cnt, - 'entries_num': len(entries), - 'acc': correct_cnt / len(entries) - } + results_count = {"correct_num": correct_cnt, "entries_num": len(entries), "acc": correct_cnt / len(entries)} return results_count def get_multi_choice_prediction(response, all_choices, index2ans): - for char in [',', '.', '!', '?', ';', ':',"'"]: + for char in [",", ".", "!", "?", ";", ":", "'"]: response = response.strip(char) - response = " " + response + " " # add space to avoid partial match + response = " " + response + " " # add space to avoid partial match candidates = [] for choice in all_choices: # (A) (B) (C) (D) # Add the choice to candidates each time it appears in the response - candidates.extend([choice for _ in range(response.count(f'({choice})'))]) + candidates.extend([choice for _ in range(response.count(f"({choice})"))]) if len(candidates) == 0: for choice in all_choices: # A B C D # Similarly, add the choice for each occurrence - candidates.extend([choice for _ in range(response.count(f'{choice}'))]) + candidates.extend([choice for _ in range(response.count(f"{choice}"))]) if len(candidates) == 0 and len(response.split()) >= 1: for index, ans in index2ans.items(): @@ -237,7 +238,7 @@ def get_multi_choice_prediction(response, all_choices, index2ans): for index, ans in index2ans.items(): if ans in response: candidates.append(index) - index_ans = False # it's content ans. + index_ans = False # it's content ans. if len(candidates) == 0: # still not get answer, randomly choose one. return random.choice(all_choices) @@ -251,15 +252,16 @@ def get_multi_choice_prediction(response, all_choices, index2ans): most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count] # Combine the most frequent candidates in ABCD order - return ''.join(most_frequent_candidates) + return "".join(most_frequent_candidates) + def extract_numbers(string): # Pattern for numbers with Chinese commas - pattern_commas = r'-?\d{1,3}(?:,\d{3})+' + pattern_commas = r"-?\d{1,3}(?:,\d{3})+" # Pattern for scientific notation - pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+' + pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+" # Pattern for simple numbers without Chinese commas - pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)' + pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)" # Extract numbers with Chinese commas numbers_with_commas = re.findall(pattern_commas, string) @@ -272,16 +274,19 @@ def extract_numbers(string): all_numbers = numbers_with_commas + numbers_scientific + numbers_simple return all_numbers + def check_is_number(string): try: - float(string.replace(',', '')) + float(string.replace(",", "")) return True except ValueError: # check if there's comma inside return False + def count_letters(string): - return sum(c.isalpha() and 'a' <= c <= 'z' or 'A' <= c <= 'Z' for c in string) + return sum(c.isalpha() and "a" <= c <= "z" or "A" <= c <= "Z" for c in string) + def normalize_str(string, answer): # check if characters in the string @@ -294,16 +299,17 @@ def normalize_str(string, answer): is_number = check_is_number(string) if is_number: - string = string.replace(',', '') + string = string.replace(",", "") string = float(string) # leave 2 decimal string = round(string, 2) return [string] - else: # it's likely to be a string + else: # it's likely to be a string if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2: return [] return [string] + def get_fill_blank_prediction(response, answer): """get the prediction from the generated response, return a list of predicted strings or numbers""" @@ -311,15 +317,14 @@ def get_fill_blank_prediction(response, answer): def get_key_subresponses(response): key_responses = [] response = response.strip("。").strip() - sub_responses = re.split(r'。|\n', response) - indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择', - '正确答案', '因此', '最后', '答案', '结果'] + sub_responses = re.split(r"。|\n", response) + indicators_of_keys = ["是", "为", "所以", "等于", "方案", "选择", "正确答案", "因此", "最后", "答案", "结果"] key_responses = [] for index, resp in enumerate(sub_responses): # if last one, accept it's an equation (the entire response can be just one sentence with equation) if index == len(sub_responses) - 1: - indicators_of_keys.extend(['=']) - shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) + indicators_of_keys.extend(["="]) + shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) for indicator in indicators_of_keys: if indicator in resp: if not shortest_key_response: @@ -327,18 +332,18 @@ def get_key_subresponses(response): else: if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response): shortest_key_response = resp.split(indicator)[-1].strip() - + if shortest_key_response: # and it's not trivial if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]: key_responses.append(shortest_key_response) - if len(key_responses) == 0: # did not found any + if len(key_responses) == 0: # did not found any return [response] return key_responses key_responses = get_key_subresponses(response) - pred_list = key_responses.copy() # keep the original string response + pred_list = key_responses.copy() # keep the original string response for resp in key_responses: pred_list.extend(extract_numbers(resp)) @@ -352,6 +357,7 @@ def get_key_subresponses(response): return pred_list + def get_TF_prediction(response): """get the prediction from the generated response, return a list of predicted strings or numbers""" @@ -359,12 +365,11 @@ def get_TF_prediction(response): def get_key_subresponses(response): key_responses = [] response = response.strip("。").strip() - sub_responses = re.split(r'。|\n', response) - indicators_of_keys = ['是', '为', '所以', '判断', - '陈述', '说法', '表达', '答案', '结果'] + sub_responses = re.split(r"。|\n", response) + indicators_of_keys = ["是", "为", "所以", "判断", "陈述", "说法", "表达", "答案", "结果"] key_responses = [] for index, resp in enumerate(sub_responses): - shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) + shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) for indicator in indicators_of_keys: if indicator in resp: if not shortest_key_response: @@ -372,25 +377,25 @@ def get_key_subresponses(response): else: if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response): shortest_key_response = resp.split(indicator)[-1].strip() - + if shortest_key_response: # and it's not trivial if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]: key_responses.append(shortest_key_response) - if len(key_responses) == 0: # did not found any + if len(key_responses) == 0: # did not found any return [response] return key_responses key_responses = get_key_subresponses(response) - pred_list = key_responses.copy() # keep the original string response + pred_list = key_responses.copy() # keep the original string response # remove duplicates pred_list = list(set(pred_list)) return pred_list -def get_multi_choice_info(options): +def get_multi_choice_info(options): start_chr = "A" all_choices = [] index2ans = {} @@ -400,8 +405,8 @@ def get_multi_choice_info(options): return index2ans, all_choices -def calculate_ins_level_acc(results): +def calculate_ins_level_acc(results): correct_sum = 0 entries_sum = 0 for cat_results in results.values(): @@ -409,4 +414,4 @@ def calculate_ins_level_acc(results): entries_sum += cat_results["entries_num"] if entries_sum == 0: return 0 - return correct_sum / entries_sum \ No newline at end of file + return correct_sum / entries_sum diff --git a/lmms_eval/tasks/dc100_en/utils.py b/lmms_eval/tasks/dc100_en/utils.py index 318f4c21..d3b47ab7 100644 --- a/lmms_eval/tasks/dc100_en/utils.py +++ b/lmms_eval/tasks/dc100_en/utils.py @@ -2,6 +2,7 @@ import requests import re import logging +import time import os import yaml from pathlib import Path @@ -36,7 +37,7 @@ def doc_to_visual(doc): Provide a few lines for explanation and the rate number at last after "Final Score:".""" -def get_chat_response(base64_image, prompt, max_retries=3): +def get_chat_response(base64_image, prompt, max_retries=5, wait_time=10): headers = { "Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json", @@ -68,8 +69,13 @@ def get_chat_response(base64_image, prompt, max_retries=3): return response_data["choices"][0]["message"]["content"] except requests.exceptions.RequestException as e: eval_logger.warning(f"Request failed on attempt {attempt+1}: {e}") + time.sleep(wait_time) if attempt == max_retries - 1: - raise + eval_logger.error(f"Failed to get response after {max_retries} attempts") + return "" + except Exception as e: + eval_logger.error(f"Error on attempt {attempt+1}: {e}") + return "" def image_to_base64(pil_image): diff --git a/lmms_eval/tasks/dc200_cn/utils.py b/lmms_eval/tasks/dc200_cn/utils.py index bd00f4b5..1e729150 100644 --- a/lmms_eval/tasks/dc200_cn/utils.py +++ b/lmms_eval/tasks/dc200_cn/utils.py @@ -6,6 +6,7 @@ import yaml from pathlib import Path from io import BytesIO +import time def doc_to_visual(doc): @@ -36,7 +37,7 @@ def doc_to_visual(doc): Provide a few lines for explanation and the rate number at last after "Final Score:".""" -def get_chat_response(base64_image, prompt, max_retries=3): +def get_chat_response(base64_image, prompt, max_retries=5, wait_time=10): headers = { "Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json", @@ -68,8 +69,13 @@ def get_chat_response(base64_image, prompt, max_retries=3): return response_data["choices"][0]["message"]["content"] except requests.exceptions.RequestException as e: eval_logger.warning(f"Request failed on attempt {attempt+1}: {e}") + time.sleep(wait_time) if attempt == max_retries - 1: - raise + eval_logger.error(f"Failed to get response after {max_retries} attempts") + return "" + except Exception as e: + eval_logger.error(f"Error on attempt {attempt+1}: {e}") + return "" def image_to_base64(pil_image): diff --git a/lmms_eval/tasks/iconqa/iconqa_test.yaml b/lmms_eval/tasks/iconqa/iconqa_test.yaml new file mode 100644 index 00000000..58aa9e70 --- /dev/null +++ b/lmms_eval/tasks/iconqa/iconqa_test.yaml @@ -0,0 +1,23 @@ +dataset_path: lmms-lab/ICON-QA +dataset_kwargs: + token: True +task: "iconqa_test" +test_split: test +output_type: generate_until +doc_to_visual: !function utils.doc_to_visual +doc_to_text: !function utils.doc_to_text +doc_to_target: "answers" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +metric_list: + - metric: anls + aggregation: mean + higher_is_better: true +model_specific_prompt_kwargs: + default: + pre_prompt: "" + statement: "Given a set of images and a question, please provide the answer to the question.\n" + options_statement: "Question: {question}.\nOptions:\n{options}\nPlease answer with the option letter from the given choices directly." + freeform_statement: "Question: {question}.\nPlease answer the question using a single word or phrase." \ No newline at end of file diff --git a/lmms_eval/tasks/iconqa/iconqa_val.yaml b/lmms_eval/tasks/iconqa/iconqa_val.yaml new file mode 100644 index 00000000..7215715e --- /dev/null +++ b/lmms_eval/tasks/iconqa/iconqa_val.yaml @@ -0,0 +1,23 @@ +dataset_path: lmms-lab/ICON-QA +dataset_kwargs: + token: True +task: "iconqa_val" +test_split: val +output_type: generate_until +doc_to_visual: !function utils.doc_to_visual +doc_to_text: !function utils.doc_to_text +doc_to_target: "answers" +generation_kwargs: + max_new_tokens: 32 + temperature: 0 + do_sample: False +metric_list: + - metric: anls + aggregation: mean + higher_is_better: true +model_specific_prompt_kwargs: + default: + pre_prompt: "" + statement: "Given a set of images and a question, please provide the answer to the question.\n" + options_statement: "Question: {question}.\nOptions:\n{options}\nPlease answer with the option letter from the given choices directly." + freeform_statement: "Question: {question}.\nPlease answer the question using a single word or phrase." \ No newline at end of file diff --git a/lmms_eval/tasks/iconqa/utils.py b/lmms_eval/tasks/iconqa/utils.py new file mode 100644 index 00000000..5d833fec --- /dev/null +++ b/lmms_eval/tasks/iconqa/utils.py @@ -0,0 +1,62 @@ +import json +import os + + +def options_to_str(options_prompt): + option_prompt_str = "" + for i, option in enumerate(options_prompt): + option_choice = chr(ord("A") + i) + option_prompt_str += f"{option_choice}. {option}\n" + + option_prompt_str = option_prompt_str.rstrip("\n") + return option_prompt_str + + +def doc_to_visual(doc): + image_list = [] + if "query_image" in doc: + image_list.append(doc["query_image"].convert("RGB")) + if "choice_image_0" in doc: + image_list.append(doc["choice_image_0"].convert("RGB")) + if "choice_image_1" in doc: + image_list.append(doc["choice_image_1"].convert("RGB")) + if "choice_image_2" in doc: + image_list.append(doc["choice_image_2"].convert("RGB")) + if "choice_image_3" in doc: + image_list.append(doc["choice_image_3"].convert("RGB")) + if "choice_image_4" in doc: + image_list.append(doc["choice_image_4"].convert("RGB")) + assert len(image_list) < 6, "Maximum 5 images allowed for ICON-QA" + return image_list + + +def doc_to_text(doc, model_specific_prompt_kwargs): + question = doc["question"] + ques_type = doc["ques_type"] + options_prompt = [] + + if ques_type == "choose_img": + options_prompt.append("The first image.") + options_prompt.append("The second image.") + + options_str = options_to_str(options_prompt) + full_prompt = f"{model_specific_prompt_kwargs['pre_prompt']}{model_specific_prompt_kwargs['statement']}{model_specific_prompt_kwargs['options_statement'].format(question=question, options=options_str)}" + + elif ques_type == "choose_txt": + choices = doc["choices"].split(",") + for i, choice in enumerate(choices): + options_prompt.append(f"{choice}") + + options_str = options_to_str(options_prompt) + full_prompt = f"{model_specific_prompt_kwargs['pre_prompt']}{model_specific_prompt_kwargs['statement']}{model_specific_prompt_kwargs['options_statement'].format(question=question, options=options_str)}" + + elif ques_type == "fill_in_blank": + full_prompt = f"{model_specific_prompt_kwargs['pre_prompt']}{model_specific_prompt_kwargs['statement']}{model_specific_prompt_kwargs['freeform_statement'].format(question=question)}" + + return full_prompt + + +def test_process_results(doc, results): + pred = results[0] + questionId = doc["questionId"] + return {"anls": {"questionId": int(questionId), "answer": pred}} diff --git a/lmms_eval/tasks/multidocvqa/utils.py b/lmms_eval/tasks/multidocvqa/utils.py index 33241cf3..698e25cc 100644 --- a/lmms_eval/tasks/multidocvqa/utils.py +++ b/lmms_eval/tasks/multidocvqa/utils.py @@ -12,36 +12,41 @@ def multidocvqa_doc_to_text(doc, model_specific_prompt_kwargs): return f"{pre_prompt}{question}{post_prompt}" + def multidocvqa_doc_to_visual(doc): - return [doc[f'image_{i}'].convert("RGB") for i in range(1, 21) if doc[f'image_{i}'] is not None] + return [doc[f"image_{i}"].convert("RGB") for i in range(1, 21) if doc[f"image_{i}"] is not None] + def multidocvqa_process_results(doc, results): pred_answer = results[0] - answer = ast.literal_eval(doc['answers']) + answer = ast.literal_eval(doc["answers"]) + + return {"anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}, "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}} - return {"anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}, - "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer}} def multidocvqa_aggregate_results_anls(results): keys = {k for result in results for k in result} results = {key: [result.get(key, None) for result in results] for key in keys} evaluator = Evaluator(case_sensitive=False) - metric = evaluator.get_metrics(results['answer'], results['pred_answer']) + metric = evaluator.get_metrics(results["answer"], results["pred_answer"]) + + return sum(metric["anls"]) / len(metric["anls"]) - return sum(metric['anls']) / len(metric['anls']) def multidocvqa_aggregate_results_accuracy(results): keys = {k for result in results for k in result} results = {key: [result.get(key, None) for result in results] for key in keys} evaluator = Evaluator(case_sensitive=False) - metric = evaluator.get_metrics(results['answer'], results['pred_answer']) + metric = evaluator.get_metrics(results["answer"], results["pred_answer"]) + + return sum(metric["accuracy"]) / len(metric["accuracy"]) - return sum(metric['accuracy']) / len(metric['accuracy']) def multidocvqa_process_test_results_for_submission(doc, results): answer = results[0] return {"submission": {"questionId": int(doc["questionId"]), "answer": answer, "answer_page": None}} + def multidocvqa_test_aggregate_results_for_submission(results): os.makedirs("./submissions", exist_ok=True) with open("./submissions/multidocvqa_test_for_submission.json", "w") as f: @@ -56,7 +61,6 @@ def multidocvqa_test_aggregate_results_for_submission(results): class Evaluator: def __init__(self, case_sensitive=False): - self.case_sensitive = case_sensitive self.get_edit_distance = levenshtein_distance self.anls_threshold = 0.5 @@ -71,7 +75,7 @@ def get_metrics(self, gt_answers, preds): batch_accuracy.append(self._calculate_accuracy(gt, pred)) batch_anls.append(self._calculate_anls(gt, pred)) - return {'accuracy': batch_accuracy, 'anls': batch_anls} + return {"accuracy": batch_accuracy, "anls": batch_anls} def _preprocess_str(self, string): if not self.case_sensitive: @@ -80,8 +84,7 @@ def _preprocess_str(self, string): return string.strip() def _calculate_accuracy(self, gt, pred): - - if pred == 'none': + if pred == "none": return 0 for gt_elm in gt: @@ -94,7 +97,7 @@ def _calculate_anls(self, gt, pred): if len(pred) == 0: return 0 - if pred == 'none': + if pred == "none": return 0 answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt] @@ -102,7 +105,8 @@ def _calculate_anls(self, gt, pred): anls = max_similarity if max_similarity >= self.anls_threshold else 0 return anls - -if __name__ == '__main__': - print('-----------------') - multidocvqa_aggregate_results_anls([{"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"}, {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"}]) \ No newline at end of file + + +if __name__ == "__main__": + print("-----------------") + multidocvqa_aggregate_results_anls([{"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"}, {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"}]) diff --git a/lmms_eval/tasks/seedbench_2/utils.py b/lmms_eval/tasks/seedbench_2/utils.py index 5ec7fccc..abcc97b8 100644 --- a/lmms_eval/tasks/seedbench_2/utils.py +++ b/lmms_eval/tasks/seedbench_2/utils.py @@ -4,21 +4,23 @@ def seed_doc_to_visual(doc): return [image.convert("RGB") for image in doc["image"]] -def parse_choice_img(choice : str, img_token : str): + +def parse_choice_img(choice: str, img_token: str): if "jpg" in choice or "png" in choice: return img_token return choice -def seed_doc_to_text(doc, model_specific_kwargs = None): + +def seed_doc_to_text(doc, model_specific_kwargs=None): question = doc["question"] - question.replace("", model_specific_kwargs['img_token']) + question.replace("", model_specific_kwargs["img_token"]) question += "\n" + f"A. {parse_choice_img(doc['choice_a'], model_specific_kwargs['img_token'])}\n" question += f"B. {parse_choice_img(doc['choice_b'], model_specific_kwargs['img_token'])}\n" question += f"C. {parse_choice_img(doc['choice_c'], model_specific_kwargs['img_token'])}\n" question += f"D. {parse_choice_img(doc['choice_d'], model_specific_kwargs['img_token'])}" - if (doc['data_type'] == "Image Generation"): - num_img_in_question = len(doc['image']) - 4 - prepend_tokens = [model_specific_kwargs['img_token']] * num_img_in_question + if doc["data_type"] == "Image Generation": + num_img_in_question = len(doc["image"]) - 4 + prepend_tokens = [model_specific_kwargs["img_token"]] * num_img_in_question question = " ".join(prepend_tokens) + "\n" + question return f"{question}\n{model_specific_kwargs['post_prompt']}" diff --git a/lmms_eval/tasks/stvqa/utils.py b/lmms_eval/tasks/stvqa/utils.py index 1f106f06..bb317f69 100644 --- a/lmms_eval/tasks/stvqa/utils.py +++ b/lmms_eval/tasks/stvqa/utils.py @@ -1,21 +1,25 @@ import os import json + def stvqa_doc_to_text(doc, model_specific_prompt_kwargs): question = doc["question"] pre_prompt = model_specific_prompt_kwargs["pre_prompt"] post_prompt = model_specific_prompt_kwargs["post_prompt"] return f"{pre_prompt}{question}{post_prompt}" + def stvqa_doc_to_visual(doc): return [doc["image"].convert("RGB")] + def stvqa_process_results(doc, results): answer = results[0] return {"submission": {"question_id": int(doc["question_id"]), "answer": answer}} + def stvqa_aggregate_submissions(results): os.makedirs("./submissions", exist_ok=True) with open("./submissions/stvqa_test_for_submission.json", "w") as f: json.dump(results, f) - return -1 \ No newline at end of file + return -1