From 81297da5bf09dda534511efbaf184734545d2b82 Mon Sep 17 00:00:00 2001
From: Li Bo
Date: Tue, 24 Sep 2024 20:56:53 +0800
Subject: [PATCH] [feat] support video evaluation for qwen2-vl and add
 mix-evals-video2text (#275)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: add new output_path saving logic and add evaluation tracker to manage samples saving process

* add: regression test

* add: regression test

* clean: unused code

* 🚫 Remove unused import for cleaner code

Eliminated the commented-out import statement for WandbLogger to tidy up the code and enhance readability. This helps maintain focus on active components and prevents confusion over unused code. A cleaner structure contributes to better maintainability in the long run. No functional changes were made, just a step towards a more streamlined codebase.

* [task] add mix_evals for video evaluation

* Merge branch 'origin/main'

* ✨ Improve model name sanitization for Hugging Face formats

* 🧹 Refactor settings for Llava OneVision model

* ✨ Enhance video and image processing capabilities

- Integrated vision processing for videos and images, improving context handling within the model.
- Added error logging for missing utility dependencies to inform users about installation requirements.
- Updated YAML configuration to standardize prompt handling for various video tasks.
- Bumped version number to indicate ongoing development status.

These changes streamline how visuals are managed in the model, contributing to better assistant responses in tasks involving media.

* 🎉 Enhance W&B logging and video playback

- Added automatic naming for W&B runs if not specified, improving organization.
- Updated video frame rate from 1.0 to 0.5 for better performance and resource management during visual content processing.
- Streamlined W&B logging by removing redundant code, ensuring cleaner execution flow.

These changes optimize logging efficiency and enhance the overall user experience.

* ✨ Refine conversation logic and adjust token limits

- Updated chat template logic for better formatting in responses, ensuring consistent handling of user and assistant roles.
- Reduced maximum new tokens in multiple evaluation files to ensure more concise outputs and improve efficiency.
- Enhanced clarity in few-shot tasks by explicitly labeling question and answer roles in generated text.
- Simplified logging of contextual and target information during evaluation, ensuring better tracking of results.

These adjustments improve the overall output quality and streamline the evaluation processes.
* feat: change qwen2 vl video reading to 0.25 fps to avoid oom

* 🎥 Update video message structure in Qwen2_VL

* Update qwen2_vl.py

---
 lmms_eval/__main__.py                         |  18 +-
 lmms_eval/models/llava_onevision.py           |  12 -
 lmms_eval/models/qwen2_vl.py                  |  66 ++--
 .../tasks/mix_evals/_default_template_yaml    |  16 +
 .../tasks/mix_evals/mix_evals_video2text.yaml |   5 +
 .../mix_evals_video2text_freeform.yaml        |  25 ++
 .../mix_evals/mix_evals_video2text_mc.yaml    |  34 +++
 .../mix_evals_video2text_openended.yaml       |  22 ++
 lmms_eval/tasks/mix_evals/utils.py            | 286 ++++++++++++++++++
 lmms_eval/utils.py                            |  13 +-
 pyproject.toml                                |   7 +-
 11 files changed, 458 insertions(+), 46 deletions(-)
 create mode 100644 lmms_eval/tasks/mix_evals/_default_template_yaml
 create mode 100644 lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml
 create mode 100644 lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml
 create mode 100644 lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml
 create mode 100644 lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml
 create mode 100644 lmms_eval/tasks/mix_evals/utils.py

diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py
index f03d82ae..a50e1ce9 100755
--- a/lmms_eval/__main__.py
+++ b/lmms_eval/__main__.py
@@ -282,6 +282,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         sys.exit(1)
 
     if args.wandb_args:
+        if "name" not in args.wandb_args:
+            name = f"{args.model}_{args.model_args}_{utils.get_datetime_str(timezone=args.timezone)}"
+            name = utils.sanitize_long_string(name)
+            args.wandb_args += f",name={name}"
         wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))
 
     # reset logger
@@ -506,16 +510,6 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
         batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
 
-        # Add W&B logging
-        if args.wandb_args:
-            try:
-                wandb_logger.post_init(results)
-                wandb_logger.log_eval_result()
-                if args.log_samples:
-                    wandb_logger.log_eval_samples(samples)
-            except Exception as e:
-                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
-
         evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None, datetime_str=datetime_str)
 
         if args.log_samples:
@@ -525,10 +519,6 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
         if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub:
             evaluation_tracker.recreate_metadata_card()
 
-        if args.wandb_args:
-            # Tear down wandb run once all the logging is done.
-            wandb_logger.run.finish()
-
         return results, samples

     return None, None

diff --git a/lmms_eval/models/llava_onevision.py b/lmms_eval/models/llava_onevision.py
index 130fa7bc..ce2c9c2d 100644
--- a/lmms_eval/models/llava_onevision.py
+++ b/lmms_eval/models/llava_onevision.py
@@ -126,18 +126,6 @@ def __init__(
         overwrite_config["mm_spatial_pool_mode"] = self.mm_spatial_pool_mode
         cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)
 
-        if cfg_pretrained.architectures[0] == "LlavaLlamaForCausalLM":  # Ugly code, only used in vicuna that needs ROPE
-            if "224" in cfg_pretrained.mm_vision_tower:
-                least_token_number = self.max_frames_num * (16 // self.mm_spatial_pool_stride) ** 2 + 1000
-            else:
-                least_token_number = self.max_frames_num * (24 // self.mm_spatial_pool_stride) ** 2 + 1000
-
-            scaling_factor = math.ceil(least_token_number / 4096)
-            if scaling_factor >= 2:
-                overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
-                overwrite_config["max_sequence_length"] = 4096 * scaling_factor
-                overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor
-
         llava_model_args["overwrite_config"] = overwrite_config
         try:
             # Try to load the model with the multimodal argument
diff --git a/lmms_eval/models/qwen2_vl.py b/lmms_eval/models/qwen2_vl.py
index 40410d44..4705a6a3 100755
--- a/lmms_eval/models/qwen2_vl.py
+++ b/lmms_eval/models/qwen2_vl.py
@@ -1,8 +1,12 @@
+import base64
+from io import BytesIO
 from typing import List, Optional, Tuple, Union
 
+import decord
 import torch
 from accelerate import Accelerator, DistributedType
 from loguru import logger as eval_logger
+from PIL import Image
 from tqdm import tqdm
 from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
 
@@ -11,6 +15,11 @@
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
 
+try:
+    from qwen_vl_utils import process_vision_info
+except ImportError:
+    eval_logger.warning("Failed to import qwen_vl_utils; please install it via `pip install qwen-vl-utils`")
+
 
 @register_model("qwen2_vl")
 class Qwen2_VL(lmms):
@@ -176,30 +185,54 @@ def _collate(x):
                     contexts[i] = contexts[i].replace("<image>", "")
 
             messages = []
-
-            if len(visuals) == 0:
-                for context in contexts:
-                    message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "text", "text": context}]}]
-                    messages.append(message)
-            else:
-                for _, context in zip(visuals, contexts):
-                    message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": context}]}]
-                    messages.append(message)
+            processed_visuals = []
+            for i, context in enumerate(contexts):
+                if "<image>" in context:
+                    context = context.replace("<image>", "")
+
+                message = [{"role": "system", "content": "You are a helpful assistant."}]
+
+                if len(visuals) > 0:
+                    visual = visuals[i] if i < len(visuals) else None
+                    if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")):  # Video file
+                        vr = decord.VideoReader(visual)
+                        first_frame = vr[0].asnumpy()
+                        height, width = first_frame.shape[:2]
+                        max_pixels = height * width
+                        message.append({"role": "user", "content": [{"type": "video", "video": visual, "max_pixels": max_pixels}, {"type": "text", "text": context}]})
+                    elif isinstance(visual, Image.Image):  # Single image
+                        base64_image = visual.convert("RGB")
+                        buffer = BytesIO()
+                        base64_image.save(buffer, format="JPEG")
+                        base64_bytes = base64.b64encode(buffer.getvalue())
+                        base64_string = base64_bytes.decode("utf-8")
+                        message.append({"role": "user", "content": [{"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"}, {"type": "text", "text": context}]})
+                    elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual):  # Multiple images
+                        image_content = []
+                        for v in visual:
+                            base64_image = v.convert("RGB")
+                            buffer = BytesIO()
+                            base64_image.save(buffer, format="JPEG")
+                            base64_bytes = base64.b64encode(buffer.getvalue())
+                            base64_string = base64_bytes.decode("utf-8")
+                            image_content.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"})
+                        message.append({"role": "user", "content": image_content + [{"type": "text", "text": context}]})
+                    else:
+                        message.append({"role": "user", "content": [{"type": "text", "text": context}]})
+                else:
+                    message.append({"role": "user", "content": [{"type": "text", "text": context}]})
+
+                messages.append(message)
 
             texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages]
-            inputs = self.processor(text=texts, images=[visuals], padding=True, return_tensors="pt")
+            image_inputs, video_inputs = process_vision_info(messages)
+            inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
 
             if self.device_map == "auto":
                 inputs = inputs.to("cuda")
             else:
                 inputs = inputs.to(self.device)
 
-            # preconfigure gen_kwargs with defaults
-            if "image_sizes" not in gen_kwargs:
-                try:
-                    gen_kwargs["image_sizes"] = [visuals[0].size]
-                except:
-                    gen_kwargs["image_sizes"] = None
             if "max_new_tokens" not in gen_kwargs:
                 gen_kwargs["max_new_tokens"] = 128
             if "temperature" not in gen_kwargs:
@@ -221,7 +254,6 @@ def _collate(x):
                 num_beams=gen_kwargs["num_beams"],
                 max_new_tokens=gen_kwargs["max_new_tokens"],
                 use_cache=self.use_cache,
-                # kwargs=gen_kwargs
             )
 
             generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
diff --git a/lmms_eval/tasks/mix_evals/_default_template_yaml b/lmms_eval/tasks/mix_evals/_default_template_yaml
new file mode 100644
index 00000000..bda3f8e8
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/_default_template_yaml
@@ -0,0 +1,16 @@
+dataset_kwargs:
+  cache_dir: mix_evals_video2text
+  token: true
+  video: true
+dataset_path: lmms-lab/MixEvals_Video2Text
+lmms_eval_specific_kwargs:
+  default:
+    post_prompt: ""
+    pre_prompt: ""
+  gpt4v:
+    post_prompt: ""
+    pre_prompt: These are frames from a video. Please answer the following questions about the video.
+metadata:
+  gpt_eval_model_name: gpt-4o-mini
+  modality: video
+  version: 0
diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml b/lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml
new file mode 100644
index 00000000..e49612a8
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml
@@ -0,0 +1,5 @@
+group: mix_evals_video2text
+task:
+# - mix_evals_video2text_openconv
+- mix_evals_video2text_mc
+- mix_evals_video2text_freeform
\ No newline at end of file
diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml b/lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml
new file mode 100644
index 00000000..e4495b50
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml
@@ -0,0 +1,25 @@
+dataset_name: "video2text_closeended_free-form"
+task: "mix_evals_video2text_freeform"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual
+doc_to_text: !function utils.mix_evals_video2text_doc_to_text
+doc_to_target: "{{target}}"
+process_results: !function utils.mix_evals_video2text_process_results_freeform
+metric_list:
+  - metric: gpt_eval
+    aggregation: !function utils.mix_evals_video2text_gpt_eval
+    higher_is_better: true
+
+generation_kwargs:
+  max_new_tokens: 16
+
+include: _default_template_yaml
+
+lmms_eval_specific_kwargs:
+  default:
+    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
+    post_prompt: "Answer the question using a single word or phrase."
+  gpt4v:
+    pre_prompt: "These are frames from a video. Please answer the following questions about the video with a short phrase."
+    post_prompt: ""
diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml b/lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml
new file mode 100644
index 00000000..fcca0731
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml
@@ -0,0 +1,34 @@
+include: _default_template_yaml
+dataset_name: "video2text_closeended_multiple-choice"
+task: "mix_evals_video2text_mc"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual
+doc_to_text: !function utils.mix_evals_video2text_doc_to_text
+doc_to_target: "{{target}}"
+
+generation_kwargs:
+  max_new_tokens: 5
+
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+
+filter_list:
+  - name: "flexible-extract"
+    filter:
+      - function: !function utils.MultiChoiceRegexFilter
+        group_select: 0
+        ignore_case: true
+        ignore_punctuation: true
+
+lmms_eval_specific_kwargs:
+  default:
+    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
+    post_prompt: "Answer with the option's letter from the given choices directly."
+  gpt4v:
+    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
+    post_prompt: "Answer with the option's letter from the given choices directly."
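
For reference (not part of the patch itself): the prompt that the multiple-choice task ultimately sends to a model is assembled by `utils.mix_evals_video2text_doc_to_text`, added further below, from the `pre_prompt`, the document's question and options, and the `post_prompt`. A minimal sketch of that assembly with a hypothetical document, assuming the MixEvals field names `prompt` and `options` used later in this patch:

# Hypothetical example document; mirrors the assembly done by utils.mix_evals_video2text_doc_to_text below.
doc = {"prompt": "What is the man doing at the end of the video?", "options": ["Cooking", "Reading", "Sleeping"]}
pre_prompt = "These are frames from a video. Please answer the following questions about the video."
post_prompt = "Answer with the option's letter from the given choices directly."

# Lettered options, then pre_prompt prepended and post_prompt appended, as in the task code.
option_prompt = "Here are the options:\n" + "\n".join(f"{chr(ord('A') + i)}. {opt.strip()}" for i, opt in enumerate(doc["options"]))
user_prompt = f"{pre_prompt}\n{doc['prompt']}\n{option_prompt}\n{post_prompt}"
print(user_prompt)
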
diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml b/lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml
new file mode 100644
index 00000000..a62b2818
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml
@@ -0,0 +1,22 @@
+include: _default_template_yaml
+dataset_name: "video2text_openended"
+task: "mix_evals_video2text_openconv"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual
+doc_to_text: !function utils.mix_evals_video2text_doc_to_text_open_convs
+doc_to_target: ""
+process_results: !function utils.mix_evals_video2text_process_results_open_convs
+
+metric_list:
+  - metric: submission
+    aggregation: !function utils.mix_evals_video2text_aggregate_gen
+    higher_is_better: true
+
+lmms_eval_specific_kwargs:
+  default:
+    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
+    post_prompt: ""
+  gpt4v:
+    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
+    post_prompt: ""
diff --git a/lmms_eval/tasks/mix_evals/utils.py b/lmms_eval/tasks/mix_evals/utils.py
new file mode 100644
index 00000000..71a7ce3f
--- /dev/null
+++ b/lmms_eval/tasks/mix_evals/utils.py
@@ -0,0 +1,286 @@
+import datetime
+import json
+import os
+import re
+import sys
+import time
+from pathlib import Path
+
+import requests
+import yaml
+from loguru import logger as eval_logger
+
+import lmms_eval.tasks._task_utils.file_utils as file_utils
+from lmms_eval.filters.extraction import ExtendedRegexFilter
+
+with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
+    raw_data = f.readlines()
+    safe_data = []
+    for i, line in enumerate(raw_data):
+        # remove function definition since yaml load cannot handle it
+        if "!function" not in line:
+            safe_data.append(line)
+
+    config = yaml.safe_load("".join(safe_data))
+
+NUM_SECONDS_TO_SLEEP = 5
+GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
+API_TYPE = os.getenv("API_TYPE", "openai")
+
+if API_TYPE == "openai":
+    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+    headers = {
+        "Authorization": f"Bearer {API_KEY}",
+        "Content-Type": "application/json",
+    }
+elif API_TYPE == "azure":
+    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+    headers = {
+        "api-key": API_KEY,
+        "Content-Type": "application/json",
+    }
+
+eval_prompt = """You are an AI assistant who will help me evaluate the quality of a model response against a few candidate ground truth answers.
+
+Scoring criteria:
+- A response that perfectly reflects the meaning of the ground truth: 1 point
+- A response that reflects none of the key points in the ground truth: 0 points
+- Some parts of the response are correct, but some parts of the ground truth are not mentioned in the response: 0.5 points
+- Some parts of the response are correct, but other parts of the response are not mentioned in the ground truth: 0.5 points
+
+Here are some examples of the scoring criteria and format:
+model response: Steam Cleaning Services
+ground truth: ["steam clean", "steam clean", "cleaning", "car", "steam clean"],
+Point: 1
+
+model response: A cowboy action shooter.
+ground truth: ["man"]
+Point: 1
+
+model response: I'm sorry, but I can't assist with that request.
+ground truth: ["quality"]
+Point: 0
+
+Let's begin this task:
+model response: {model_response}
+ground truth: {ground_truth}
+Point:"""
+
+
+def get_eval(model_response: str, ground_truth: str, max_tokens: int, retries: int = 5):
+    global headers
+    content = eval_prompt.format(model_response=model_response, ground_truth=ground_truth)
+
+    messages = [
+        {"role": "user", "content": content},
+    ]
+
+    payload = {
+        "model": GPT_EVAL_MODEL_NAME,
+        "messages": messages,
+        "temperature": 0.2,
+        "max_tokens": max_tokens,
+    }
+
+    for attempt in range(retries):
+        try:
+            response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+            response.raise_for_status()
+            response_data = response.json()
+
+            content = response_data["choices"][0]["message"]["content"].strip()
+            if content != "":
+                return content, response_data["model"]
+            break  # If successful, break out of the loop
+
+        except Exception as e:
+            eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
+            if attempt < retries:  # If we have retries left, sleep and then continue to next attempt
+                time.sleep(NUM_SECONDS_TO_SLEEP)
+            else:  # If this was the last attempt, log and return empty
+                eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
+                return "", ""
+    return "", ""
+
+
+# A bit ugly here
+# But the idea is that we will unzip all the zip files
+# To HF HOME cache dir
+# And load it here
+HF_HOME = os.environ["HF_HOME"]
+cache_dir = config["dataset_kwargs"]["cache_dir"]
+cache_dir = os.path.join(HF_HOME, cache_dir)
+cache_dir = os.path.join(cache_dir)
+
+
+# Pass in video path here
+# Can only work correctly with video llm
+def mix_evals_video2text_doc_to_visual(doc):
+    video_path = doc["video_path"]
+    video_path = os.path.join(cache_dir, video_path)
+    if os.path.exists(video_path):
+        video_path = video_path
+    elif os.path.exists(video_path.replace("mp4", "MP4")):
+        video_path = video_path.replace("mp4", "MP4")
+    else:
+        sys.exit(f"video path:{video_path} does not exist, please check")
+    return [video_path]
+
+
+# This is the place where you format your question
+def mix_evals_video2text_doc_to_text(doc, model_specific_prompt_kwargs=None):
+    if model_specific_prompt_kwargs is None:
+        model_specific_prompt_kwargs = {}
+    pre_prompt = ""
+    post_prompt = ""
+    if "pre_prompt" in model_specific_prompt_kwargs:
+        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    if "post_prompt" in model_specific_prompt_kwargs:
+        post_prompt = model_specific_prompt_kwargs["post_prompt"]
+
+    user_prompt = doc["prompt"]
+
+    if "options" in doc:
+        option_prompt = "Here are the options:\n"
+        for idx, option in enumerate(doc["options"]):
+            char_idx = chr(ord("A") + idx)
+            option = option.strip()
+            option_prompt += f"{char_idx}. {option}\n"
+
+        option_prompt = option_prompt.rstrip("\n")
+        user_prompt = f"{user_prompt}\n{option_prompt}"
+
+    if pre_prompt:
+        user_prompt = f"{pre_prompt}\n{user_prompt}"
+
+    if post_prompt:
+        user_prompt = f"{user_prompt}\n{post_prompt}"
+    return user_prompt
+
+
+OPEN_CONVS_PROMPT = """{PRE}
+{FIRST}
+{POST}
+"""
+
+
+def mix_evals_video2text_doc_to_text_open_convs(doc, model_specific_prompt_kwargs=None):
+    if model_specific_prompt_kwargs is None:
+        model_specific_prompt_kwargs = {}
+    pre_prompt = ""
+    post_prompt = ""
+    if "pre_prompt" in model_specific_prompt_kwargs:
+        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    if "post_prompt" in model_specific_prompt_kwargs:
+        post_prompt = model_specific_prompt_kwargs["post_prompt"]
+
+    filtered_first_turn = re.sub(r"<video>", "", doc["first_turn_user_prompt"])
+    return OPEN_CONVS_PROMPT.format(
+        PRE=pre_prompt,
+        POST=post_prompt,
+        FIRST=filtered_first_turn,
+    )
+
+
+MODEL_CONVS_PROMPT = """{FIRST}
+{MODEL_RESPONSE}
+{PRE}
+{SECOND}
+{POST}
+"""
+
+
+def mix_evals_video2text_doc_to_text_open_2nd_convs(doc, model_specific_prompt_kwargs=None):
+    if model_specific_prompt_kwargs is None:
+        model_specific_prompt_kwargs = {}
+    pre_prompt = ""
+    post_prompt = ""
+    if "pre_prompt" in model_specific_prompt_kwargs:
+        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    if "post_prompt" in model_specific_prompt_kwargs:
+        post_prompt = model_specific_prompt_kwargs["post_prompt"]
+
+    return MODEL_CONVS_PROMPT.format(
+        PRE=pre_prompt,
+        POST=post_prompt,
+        FIRST=doc["first_turn_user_prompt"],
+        SECOND=doc["second_turn_user_prompt"],
+        MODEL_RESPONSE=doc["model_response"],
+    )
+
+
+def mix_evals_video2text_process_results_open_convs(doc, result):
+    pred = result[0]
+    return {"submission": {"pred": pred, "question_idx": doc["question_index"], "first_turn_video_caption": doc["first_turn_video_caption"], "target": ""}}
+
+
+def mix_evals_video2text_process_results_freeform(doc, result):
+    pred = result[0]
+    ground_truth_str = ", ".join([f'"{gt}"' for gt in doc["target"]])
+    ground_truth_str = f"[{ground_truth_str}]"
+    content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str)
+    eval_answer, model_name = get_eval(model_response=pred, ground_truth=ground_truth_str, max_tokens=1024)
+    return {
+        "submission": {"pred": pred, "question_idx": doc["question_index"], "target": doc["target"], "eval_answer": eval_answer, "gpt_prompt": content},
+        "gpt_eval": {"pred": pred, "question_idx": doc["question_index"], "target": doc["target"], "eval_answer": eval_answer, "gpt_prompt": content},
+    }
+
+
+def mix_evals_video2text_aggregate_submissions(results, args, task):
+    now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    submission_file_name = f"mix_evals_video2text_{task}-{now_date_time}.json"
+    path = file_utils.generate_submission_file(submission_file_name, args)
+    with open(path, "w") as f:
+        json.dump(results, f)
+    eval_logger.info(f"Submission file saved to {path}")
+
+
+def mix_evals_video2text_gpt_eval(results, args):
+    score = 0
+    for result in results:
+        eval_answer = result["eval_answer"]
+        eval_score = re.search(r"([0-9.]+)", eval_answer).group(1)
+        try:
+            eval_score = float(eval_score)
+        except Exception as e:
+            eval_logger.error(f"Error parsing eval_score: {e}")
+            eval_score = 0.0
+        score += eval_score
+
+    return score / len(results)
+
+
+# Factory into different aggregate
+def mix_evals_video2text_aggregate_gen(results, args):
+    mix_evals_video2text_aggregate_submissions(results, args, "OpenConvs")
+
+
+class MultiChoiceRegexFilter(ExtendedRegexFilter):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def apply(self, resps, docs):
+        filtered_resps = []
+
+        for r, doc in zip(resps, docs):
+            # Regex to directly extract the option letter from the model response
+            option_letter_regex = re.compile(r"\b([A-Z])\.\s+([^\n]*)")
+
+            # Process each response
+            filtered = []
+            for resp in r:
+                # Try to match the option letter at the start of the response
+                match = option_letter_regex.match(resp)
+                if match:
+                    # If a match is found, append the matched letter
+                    filtered.append(match.group(1))
+                else:
+                    # If no match, return the original response
+                    filtered.append(resp)
+
+            # Assuming we need the first response that matches or the original response
+            filtered_resps.append(filtered[0])
+
+        return filtered_resps
diff --git a/lmms_eval/utils.py b/lmms_eval/utils.py
index d9452a40..1df086ae 100755
--- a/lmms_eval/utils.py
+++ b/lmms_eval/utils.py
@@ -246,7 +246,9 @@ def sanitize_model_name(model_name: str, full_path: bool = False) -> str:
     if full_path:
         return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
     else:
-        return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name.split("/")[-1])
+        parts = model_name.split("/")
+        last_two = "/".join(parts[-2:]) if len(parts) > 1 else parts[-1]  # accommodate models in Hugging Face Hub format, e.g. lmms-lab/llava-onevision-qwen2-0.5b
+        return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", last_two)
 
 
 def sanitize_task_name(task_name: str) -> str:
@@ -274,7 +276,7 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
     """
     Extracts filenames that correspond to sample results.
    """
-    return [f for f in filenames if "samples" in f and ".json" in f]
+    return [f for f in filenames if "/samples_" in f and ".json" in f]
 
 
 def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
@@ -593,6 +595,13 @@ def get_datetime_str(timezone="Asia/Singapore"):
     utc_now = datetime.datetime.now(datetime.timezone.utc)
     local_time = utc_now.astimezone(tz)
     return local_time.strftime("%Y%m%d_%H%M%S")
+    return local_time.strftime("%Y%m%d_%H%M%S")
+
+
+def sanitize_long_string(s, max_length=40):
+    if len(s) > max_length:
+        return s[: max_length // 2] + "..." + s[-max_length // 2 :]
+    return s
 
 
 def ignore_constructor(loader, node):
diff --git a/pyproject.toml b/pyproject.toml
index 490ce230..f2e7d5bc 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "lmms_eval"
-version = "0.2.3"
+version = "0.2.3.dev0"
 authors = [
     { name = "LMMMs-Lab Evaluation Team", email = "lmms-lab@outlook.com" },
 ]
@@ -91,11 +91,16 @@ reka = [
     "httpx==0.23.3",
     "reka-api",
 ]
+qwen = [
+    "decord",
+    "qwen_vl_utils",
+]
 all = [
     "vila",
     "gemini",
     "reka",
     "metrics",
+    "qwen",
 ]
 
 [tool.setuptools.packages.find]
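
For readers trying out the patched model path, a minimal standalone sketch of the message layout Qwen2_VL now builds before delegating to `qwen_vl_utils.process_vision_info` and the Hugging Face processor. This is illustrative only and not part of the patch; the checkpoint name "Qwen/Qwen2-VL-7B-Instruct" and the local file "clip.mp4" are placeholders, and the max_pixels cap derived from the first decoded frame in the patch is omitted for brevity.

# Illustrative sketch of the video request flow used by the patched qwen2_vl wrapper (placeholder names).
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

messages = [[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": [
        {"type": "video", "video": "clip.mp4"},  # a local video path, as returned by mix_evals_video2text_doc_to_visual
        {"type": "text", "text": "Describe what happens in this video."},
    ]},
]]

texts = [processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]
image_inputs, video_inputs = process_vision_info(messages)  # same helper the patch imports
inputs = processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(model.device)

out = model.generate(**inputs, max_new_tokens=128)
trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, out)]  # strip the prompt tokens, as the patch does
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])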