Skip to content

Commit

Permalink
add MMMU - a benchmark for vLM
Browse files Browse the repository at this point in the history
  • Loading branch information
mickqian committed Feb 22, 2025
1 parent 14d9061 commit 15c25d1
Show file tree
Hide file tree
Showing 9 changed files with 1,020 additions and 8 deletions.
15 changes: 15 additions & 0 deletions benchmark/mmmu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
## Run evaluation

### Evaluate sglang

```
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
```

It's recommended to reduce the memory usage by appending something ike `--mem-fraction-static 0.6` to the command above.

### Evaluate hf

```
python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct
```
119 changes: 119 additions & 0 deletions benchmark/mmmu/bench_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Bench the huggingface vLM with benchmark MMMU
Usage:
python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct
The eval output will be logged
"""

import argparse
import random

import torch
from bench_sglang import EvalArgs, prepare_samples
from data_utils import save_json
from eval_utils import eval_result, get_sampling_params, parse_multi_choice_response
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor


@torch.no_grad()
def eval_mmmu(args):
eval_args = EvalArgs.from_cli_args(args)

model = AutoModelForImageTextToText.from_pretrained(
args.model_path,
torch_dtype="auto",
trust_remote_code=True,
)
model = model.eval().cuda()
model = torch.compile(model)

processor = AutoProcessor.from_pretrained(
args.model_path, torch_dtype="auto", device_map="auto"
)

samples = prepare_samples(eval_args)
out_samples = dict()

sampling_params = get_sampling_params(eval_args)

answer_dict = {}
for sample in tqdm(samples):
prompt = sample["final_input_prompt"]
image = sample["image"]
prefix = prompt.split("<")[0]
suffix = prompt.split(">")[1]
if image is not None:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prefix},
{
"type": "image",
"image": image,
},
{"type": "text", "text": suffix},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

inputs = processor(
text=[text],
images=[image],
padding=True,
return_tensors="pt",
).to(model.device)

generated_ids = model.generate(**inputs, **sampling_params)

response = processor.decode(
generated_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[len(text) :]
else: # multiple images actually
if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"]
response = random.choice(all_choices)

else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"

if sample["question_type"] == "multiple-choice":
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans

torch.cuda.empty_cache()
# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": sample["answer"],
}

args.output_path = f"{args.model_path}_val_hf.json"
save_json(args.output_path, out_samples)
eval_result(output_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-path",
type=str,
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
required=True,
)
EvalArgs.add_cli_args(parser)
args = parser.parse_args()

eval_mmmu(args)
101 changes: 101 additions & 0 deletions benchmark/mmmu/bench_sglang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Bench the sglang-hosted vLM with benchmark MMMU
Usage:
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
The eval output will be logged
"""

import argparse
import dataclasses
import random
import re
from io import BytesIO

from data_utils import save_json
from eval_utils import (
EvalArgs,
eval_result,
get_sampling_params,
parse_multi_choice_response,
prepare_samples,
)
from tqdm import tqdm

from sglang import Engine
from sglang.srt.conversation import chat_templates
from sglang.srt.server_args import ServerArgs


def eval_mmmu(args):
server_args = ServerArgs.from_cli_args(args)
eval_args = EvalArgs.from_cli_args(args)

if server_args.chat_template is None:
raise ValueError("Chat template must be provided for this benchmark")

samples = prepare_samples(eval_args)

backend = Engine(**dataclasses.asdict(server_args))

out_samples = dict()

sampling_params = get_sampling_params(eval_args)

conv = chat_templates[server_args.chat_template].copy()
image_token = conv.image_token
answer_dict = {}
for sample in tqdm(samples):
prompt = sample["final_input_prompt"]
image = sample["image"]
bytes_io = BytesIO()
image.save(bytes_io, format="PNG")
png_bytes = bytes_io.getvalue()

prompt = re.sub(r"<[^>]*>", image_token, prompt)

if image is not None:
gen_out = backend.generate(
prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params
)["text"]

response = gen_out
else: # multiple images actually
if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"]
response = random.choice(all_choices)

else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"

if sample["question_type"] == "multiple-choice":
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans

# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": (
sample["correct_choice"]
if "correct_choice" in samples
else sample["answer"]
),
}

args.output_path = f"{args.model_path}_val_sglang.json"
save_json(args.output_path, out_samples)
eval_result(output_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
EvalArgs.add_cli_args(parser)
args = parser.parse_args()

eval_mmmu(args)
Loading

0 comments on commit 15c25d1

Please sign in to comment.