[feat] support video evaluation for qwen2-vl and add mix-evals-video2text (#275)

* feat: add new output_path saving logic and an evaluation tracker to manage the sample-saving process

* add: regression test

* add: regression test

* clean: remove unused code

* 🚫 Remove unused import for cleaner code

Removed the commented-out import of WandbLogger to tidy up the code and improve readability, keeping the focus on active components and avoiding confusion over dead code.

No functional changes were made, just a step towards a more streamlined codebase.

* [task] add mix_evals for video evaluation

* Merge branch 'origin/main'

* ✨ Improve model name sanitization for Hugging Face formats

* 🧹 Refactor settings for Llava OneVision model

* ✨ Enhance video and image processing capabilities

- Integrated vision processing for videos and images, improving context handling within the model.
- Added error logging for missing utility dependencies to inform users about installation requirements.
- Updated YAML configuration to standardize prompt handling for various video tasks.
- Bumped version number to indicate ongoing development status.

These changes streamline how visuals are managed in the model, contributing to better assistant responses in tasks involving media.

* 🎉 Enhance W&B logging and video playback

- Added automatic naming for W&B runs if not specified, improving organization.
- Updated video frame rate from 1.0 to 0.5 for better performance and resource management during visual content processing.
- Streamlined W&B logging by removing redundant code, ensuring cleaner execution flow.

These changes optimize logging efficiency and enhance the overall user experience.

* ✨ Refine conversation logic and adjust token limits

- Updated chat template logic for better formatting in responses, ensuring consistent handling of user and assistant roles.
- Reduced maximum new tokens in multiple evaluation files to ensure more concise outputs and improve efficiency.
- Enhanced clarity in few-shot tasks by explicitly labeling question and answer roles in generated text.
- Simplified logging of contextual and target information during evaluation, ensuring better tracking of results.

These adjustments improve the overall output quality and streamline the evaluation processes.

* feat: change qwen2-vl video reading to 0.25 fps to avoid OOM

* 🎥 Update video message structure in Qwen2_VL

* Update qwen2_vl.py
Luodian authored and KairuiHu committed Oct 24, 2024
1 parent 91c50aa commit 81297da
Showing 11 changed files with 458 additions and 46 deletions.
18 changes: 4 additions & 14 deletions lmms_eval/__main__.py
@@ -282,6 +282,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
        sys.exit(1)

    if args.wandb_args:
        if "name" not in args.wandb_args:
            name = f"{args.model}_{args.model_args}_{utils.get_datetime_str(timezone=args.timezone)}"
            name = utils.sanitize_long_string(name)
            args.wandb_args += f",name={name}"
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    # reset logger
@@ -506,16 +510,6 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if args.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if args.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None, datetime_str=datetime_str)

        if args.log_samples:
@@ -525,10 +519,6 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
        if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub:
            evaluation_tracker.recreate_metadata_card()

        if args.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()

        return results, samples
    return None, None

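For illustration, a rough sketch of the run name the new block above builds when --wandb_args is passed without an explicit name (values are hypothetical, and utils.sanitize_long_string is assumed to truncate overly long names):

# Hypothetical values, for illustration of the auto-naming logic added to cli_evaluate above.
model = "qwen2_vl"
model_args = "pretrained=Qwen/Qwen2-VL-7B-Instruct"
datetime_str = "1024_1530"  # assumed output format of utils.get_datetime_str(timezone=...)
name = f"{model}_{model_args}_{datetime_str}"
# -> "qwen2_vl_pretrained=Qwen/Qwen2-VL-7B-Instruct_1024_1530"
# utils.sanitize_long_string(name) is then assumed to shorten this if it exceeds W&B's
# run-name length limit before ",name=..." is appended to wandb_args.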
12 changes: 0 additions & 12 deletions lmms_eval/models/llava_onevision.py
@@ -126,18 +126,6 @@ def __init__(
        overwrite_config["mm_spatial_pool_mode"] = self.mm_spatial_pool_mode
        cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)

        if cfg_pretrained.architectures[0] == "LlavaLlamaForCausalLM": # Ugly code, only used in vicuna that needs ROPE
            if "224" in cfg_pretrained.mm_vision_tower:
                least_token_number = self.max_frames_num * (16 // self.mm_spatial_pool_stride) ** 2 + 1000
            else:
                least_token_number = self.max_frames_num * (24 // self.mm_spatial_pool_stride) ** 2 + 1000

            scaling_factor = math.ceil(least_token_number / 4096)
            if scaling_factor >= 2:
                overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
                overwrite_config["max_sequence_length"] = 4096 * scaling_factor
                overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor

        llava_model_args["overwrite_config"] = overwrite_config
        try:
            # Try to load the model with the multimodal argument
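To make the removed branch concrete, a hypothetical worked example of its arithmetic (numbers are illustrative, not taken from the repo):

import math

# Illustrative values only -- this just replays the arithmetic of the removed ROPE-scaling branch.
max_frames_num = 32
mm_spatial_pool_stride = 2
least_token_number = max_frames_num * (24 // mm_spatial_pool_stride) ** 2 + 1000  # 32 * 144 + 1000 = 5608
scaling_factor = math.ceil(least_token_number / 4096)  # ceil(5608 / 4096) = 2
# Since scaling_factor >= 2, the old code would have set linear rope_scaling with factor 2.0
# and raised max_sequence_length / tokenizer_model_max_length to 4096 * 2 = 8192.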
66 changes: 49 additions & 17 deletions lmms_eval/models/qwen2_vl.py
@@ -1,8 +1,12 @@
import base64
from io import BytesIO
from typing import List, Optional, Tuple, Union

import decord
import torch
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration

@@ -11,6 +15,11 @@
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

try:
    from qwen_vl_utils import process_vision_info
except ImportError:
    eval_logger.warning("Failed to import qwen_vl_utils; Please install it via `pip install qwen-vl-utils`")


@register_model("qwen2_vl")
class Qwen2_VL(lmms):
@@ -176,30 +185,54 @@ def _collate(x):
                    contexts[i] = contexts[i].replace("<image>", "")

            messages = []

            if len(visuals) == 0:
                for context in contexts:
                    message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "text", "text": context}]}]
                    messages.append(message)
            else:
                for _, context in zip(visuals, contexts):
                    message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": context}]}]
                    messages.append(message)
            processed_visuals = []
            for i, context in enumerate(contexts):
                if "<image>" in context:
                    context = context.replace("<image>", "")

                message = [{"role": "system", "content": "You are a helpful assistant."}]

                if len(visuals) > 0:
                    visual = visuals[i] if i < len(visuals) else None
                    if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file
                        vr = decord.VideoReader(visual)
                        first_frame = vr[0].asnumpy()
                        height, width = first_frame.shape[:2]
                        max_pixels = height * width
                        message.append({"role": "user", "content": [{"type": "video", "video": visual, "max_pixels": max_pixels}, {"type": "text", "text": context}]})
                    elif isinstance(visual, Image.Image): # Single image
                        base64_image = visual.convert("RGB")
                        buffer = BytesIO()
                        base64_image.save(buffer, format="JPEG")
                        base64_bytes = base64.b64encode(buffer.getvalue())
                        base64_string = base64_bytes.decode("utf-8")
                        message.append({"role": "user", "content": [{"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"}, {"type": "text", "text": context}]})
                    elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual): # Multiple images
                        image_content = []
                        for v in visual:
                            base64_image = v.convert("RGB")
                            buffer = BytesIO()
                            base64_image.save(buffer, format="JPEG")
                            base64_bytes = base64.b64encode(buffer.getvalue())
                            base64_string = base64_bytes.decode("utf-8")
                            image_content.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"})
                        message.append({"role": "user", "content": image_content + [{"type": "text", "text": context}]})
                    else:
                        message.append({"role": "user", "content": [{"type": "text", "text": context}]})
                else:
                    message.append({"role": "user", "content": [{"type": "text", "text": context}]})

                messages.append(message)

            texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages]
            inputs = self.processor(text=texts, images=[visuals], padding=True, return_tensors="pt")
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

            if self.device_map == "auto":
                inputs = inputs.to("cuda")
            else:
                inputs = inputs.to(self.device)

            # preconfigure gen_kwargs with defaults
            if "image_sizes" not in gen_kwargs:
                try:
                    gen_kwargs["image_sizes"] = [visuals[0].size]
                except:
                    gen_kwargs["image_sizes"] = None
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 128
            if "temperature" not in gen_kwargs:
@@ -221,7 +254,6 @@ def _collate(x):
                num_beams=gen_kwargs["num_beams"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                use_cache=self.use_cache,
                # kwargs=gen_kwargs
            )

            generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
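For orientation, a minimal standalone sketch of the new preprocessing path, following the documented qwen-vl-utils usage; the checkpoint, video path, and prompt are placeholders, not values from this commit:

from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info

# Placeholder inputs -- a sketch of the message format the model class now builds per request.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": [
        {"type": "video", "video": "/path/to/clip.mp4", "max_pixels": 360 * 420},
        {"type": "text", "text": "Describe this video."},
    ]},
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")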
16 changes: 16 additions & 0 deletions lmms_eval/tasks/mix_evals/_default_template_yaml
@@ -0,0 +1,16 @@
dataset_kwargs:
  cache_dir: mix_evals_video2text
  token: true
  video: true
dataset_path: lmms-lab/MixEvals_Video2Text
lmms_eval_specific_kwargs:
  default:
    post_prompt: ""
    pre_prompt: ""
  gpt4v:
    post_prompt: ""
    pre_prompt: These are frames from a video. Please answer the following questions about the video.
metadata:
  gpt_eval_model_name: gpt-4o-mini
  modality: video
  version: 0
5 changes: 5 additions & 0 deletions lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml
@@ -0,0 +1,5 @@
group: mix_evals_video2text
task:
  # - mix_evals_video2text_openconv
  - mix_evals_video2text_mc
  - mix_evals_video2text_freeform
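With these YAMLs in place, the group can be run through the existing lmms_eval CLI; an illustrative invocation (checkpoint name and flag values are examples, adjust to your setup):

python3 -m lmms_eval \
    --model qwen2_vl \
    --model_args pretrained=Qwen/Qwen2-VL-7B-Instruct \
    --tasks mix_evals_video2text \
    --batch_size 1 \
    --log_samples \
    --output_path ./logs/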
25 changes: 25 additions & 0 deletions lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml
@@ -0,0 +1,25 @@
dataset_name: "video2text_closeended_free-form"
task: "mix_evals_video2text_freeform"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual
doc_to_text: !function utils.mix_evals_video2text_doc_to_text
doc_to_target: "{{target}}"
process_results: !function utils.mix_evals_video2text_process_results_freeform
metric_list:
  - metric: gpt_eval
    aggregation: !function utils.mix_evals_video2text_gpt_eval
    higher_is_better: true

generation_kwargs:
  max_new_tokens: 16

include: _default_template_yaml

lmms_eval_specific_kwargs:
  default:
    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
    post_prompt: "Answer the question using a single word or phrase."
  gpt4v:
    pre_prompt: "These are frames from a video. Please answer the following questions about the video with a short phrase."
    post_prompt: ""
34 changes: 34 additions & 0 deletions lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml
@@ -0,0 +1,34 @@
include: _default_template_yaml
dataset_name: "video2text_closeended_multiple-choice"
task: "mix_evals_video2text_mc"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual
doc_to_text: !function utils.mix_evals_video2text_doc_to_text
doc_to_target: "{{target}}"

generation_kwargs:
  max_new_tokens: 5

metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true

filter_list:
- name: "flexible-extract"
filter:
- function: !function utils.MultiChoiceRegexFilter
group_select: 0
ignore_case: true
ignore_punctuation: true

lmms_eval_specific_kwargs:
  default:
    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
    post_prompt: "Answer with the option's letter from the given choices directly."
  gpt4v:
    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
    post_prompt: "Answer with the option's letter from the given choices directly."
22 changes: 22 additions & 0 deletions lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml
@@ -0,0 +1,22 @@
include: _default_template_yaml
dataset_name: "video2text_openended"
task: "mix_evals_video2text_openconv"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual
doc_to_text: !function utils.mix_evals_video2text_doc_to_text_open_convs
doc_to_target: ""
process_results: !function utils.mix_evals_video2text_process_results_open_convs

metric_list:
  - metric: submission
    aggregation: !function utils.mix_evals_video2text_aggregate_gen
    higher_is_better: true

lmms_eval_specific_kwargs:
  default:
    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
    post_prompt: ""
  gpt4v:
    pre_prompt: "These are frames from a video. Please answer the following questions about the video."
    post_prompt: ""