
Commit

…into internal_main_dev
Luodian committed Sep 1, 2024
2 parents fea9321 + 7adbd69 commit e057923
Showing 67 changed files with 2,776 additions and 41 deletions.
34 changes: 30 additions & 4 deletions README.md
@@ -4,13 +4,22 @@

# The Evaluation Suite of Large Multimodal Models

[![PyPI](https://img.shields.io/pypi/v/lmms-eval)](https://pypi.org/project/lmms-eval)
![PyPI - Downloads](https://img.shields.io/pypi/dm/lmms-eval)
![GitHub contributors](https://img.shields.io/github/contributors/EvolvingLMMs-Lab/lmms-eval)
[![issue resolution](https://img.shields.io/github/issues-closed-raw/EvolvingLMMs-Lab/lmms-eval)](https://github.com/EvolvingLMMs-Lab/lmms-eval/issues)
[![open issues](https://img.shields.io/github/issues-raw/EvolvingLMMs-Lab/lmms-eval)](https://github.com/EvolvingLMMs-Lab/lmms-eval/issues)

> Accelerating the development of large multimodal models (LMMs) with `lmms-eval`
🏠 [LMMs-Lab Homepage](https://lmms-lab.github.io/) | 🎉 [Blog](https://lmms-lab.github.io/lmms-eval-blog/lmms-eval-0.1/) | 📚 [Documentation](docs/README.md) | 🤗 [Huggingface Datasets](https://huggingface.co/lmms-lab) | <a href="https://emoji.gg/emoji/1684-discord-thread"><img src="https://cdn3.emoji.gg/emojis/1684-discord-thread.png" width="14px" height="14px" alt="Discord_Thread"></a> [discord/lmms-eval](https://discord.gg/zdkwKUqrPy)
🏠 [LMMs-Lab Homepage](https://lmms-lab.framer.ai) | 🎉 [Blog](https://lmms-lab.framer.ai/blog/lmms-eval-blog) | 📚 [Documentation](docs/README.md) | 🤗 [Huggingface Datasets](https://huggingface.co/lmms-lab) | <a href="https://emoji.gg/emoji/1684-discord-thread"><img src="https://cdn3.emoji.gg/emojis/1684-discord-thread.png" width="14px" height="14px" alt="Discord_Thread"></a> [discord/lmms-eval](https://discord.gg/zdkwKUqrPy)

---

## Announcement
- [2024-08] 🎉🎉 We welcome the new models [LLaVA-OneVision](https://huggingface.co/papers/2408.03326) and [Mantis](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/162), and the new tasks [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench), [LongVideoBench](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/117), and [MMStar](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/158). We also provide a new SGLang Runtime API for the llava-onevision model; please refer to the [doc](https://github.com/EvolvingLMMs-Lab/lmms-eval/blob/main/docs/commands.md) for inference acceleration.

- [2024-07] 🎉🎉 We have released the [technical report](https://arxiv.org/abs/2407.12772) and [LiveBench](https://huggingface.co/spaces/lmms-lab/LiveBench)!

- [2024-07] 👨‍💻👨‍💻 `lmms-eval/v0.2.1` has been upgraded to support more models, including [LongVA](https://github.com/EvolvingLMMs-Lab/LongVA), [InternVL-2](https://github.com/OpenGVLab/InternVL), and [VILA](https://github.com/NVlabs/VILA), as well as many more evaluation tasks, e.g. [Details Captions](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/136), [MLVU](https://arxiv.org/abs/2406.04264), [WildVision-Bench](https://huggingface.co/datasets/WildVision/wildvision-arena-data), [VITATECS](https://github.com/lscpku/VITATECS), and [LLaVA-Interleave-Bench](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/).

@@ -48,7 +57,7 @@ cd lmms-eval
pip install -e .
```

If you wanted to test llava, you will have to clone their repo from [LLaVA](https://github.com/haotian-liu/LLaVA) and
If you want to test LLaVA, you will have to clone their repo from [LLaVA](https://github.com/haotian-liu/LLaVA) and
```bash
# for llava 1.5
# git clone https://github.com/haotian-liu/LLaVA
@@ -68,7 +77,7 @@ You can check the [environment install script](miscs/repr_scripts.sh) and [torch

</details>

If you want to test on caption dataset such as `coco`, `refcoco`, and `nocaps`, you will need to have `java==1.8.0 ` to let pycocoeval api to work. If you don't have it, you can install by using conda
If you want to test on caption datasets such as `coco`, `refcoco`, and `nocaps`, you will need `java==1.8.0` for the pycocoeval API to work. If you don't have it, you can install it with conda:
```
conda install openjdk=8
```
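A quick way to confirm that the Java runtime pycocoeval expects is the one on your PATH (a sanity check, not part of the official setup):
```bash
# should report version 1.8.x (OpenJDK 8)
java -version
```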
@@ -92,6 +101,11 @@ We also provide the raw data exported from Weights & Biases for the detailed res
</details>
<br>

If you want to test [VILA](https://github.com/NVlabs/VILA), you should install the following dependencies:

```bash
pip install s2wrapper@git+https://github.com/bfshi/scaling_on_scales
```
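A minimal sanity check that the dependency resolved correctly (the import name `s2wrapper` is assumed from the package spec above):
```bash
# should print OK if the s2wrapper module imports cleanly
python -c "import s2wrapper; print('s2wrapper OK')"
```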

Our development continues on the main branch, and we encourage you to give us feedback on desired features and further improvements to the library, or to ask questions, either in issues or PRs on GitHub.

@@ -165,6 +179,18 @@ python3 -m accelerate.commands.launch \
python3 -m accelerate.commands.launch --num_processes=8 -m lmms_eval --config ./miscs/example_eval.yaml
```
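Each entry in such a config file mirrors the CLI flags. A hypothetical minimal config is sketched below; the field names are assumptions, so consult `miscs/example_eval.yaml` in the repository for the authoritative schema:
```bash
# Write a throwaway config that mirrors the CLI flags shown in this README.
# The keys below are illustrative assumptions, not the verified schema.
cat > ./my_eval.yaml << 'EOF'
- model: llava
  model_args: pretrained=liuhaotian/llava-v1.5-7b
  tasks: mme
  batch_size: 1
  log_samples: true
  log_samples_suffix: demo
  output_path: "./logs/"
EOF

python3 -m accelerate.commands.launch --num_processes=8 -m lmms_eval --config ./my_eval.yaml
```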

**Evaluation of a video model (llava-next-video-32B)**
```bash
accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \
--model llavavid \
--model_args pretrained=lmms-lab/LLaVA-NeXT-Video-32B-Qwen,conv_template=qwen_1_5,video_decode_backend=decord,max_frames_num=32,mm_spatial_pool_mode=average,mm_newline_position=grid,mm_resampler_location=after \
--tasks videomme \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava_vid_32B \
--output_path ./logs/
```

**Evaluation with naive model sharding for bigger models (llava-next-72b)**

```bash
@@ -199,7 +225,7 @@ Please check [supported models](lmms_eval/models/__init__.py) for more details.

### Supported tasks

Please check [supported tasks](lmms_eval/docs/current_tasks.md) for more details.
Please check [supported tasks](docs/current_tasks.md) for more details.

## Add Customized Model and Dataset

47 changes: 47 additions & 0 deletions docs/commands.md
@@ -193,3 +193,50 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
```


## Usage with SRT API

> install sglang
```bash
git clone https://github.com/sgl-project/sglang.git
# Current version is tested on #1222
cd sglang;
pip install -e "python[srt]"

# Install FlashInfer CUDA kernels
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
```
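To verify the install before launching an evaluation, a quick import check (assuming the package exposes `__version__`, which recent releases do):
```bash
# prints the installed sglang version if the build succeeded
python -c "import sglang; print(sglang.__version__)"
```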

> run sglang backend service with the following command
```bash
# After this update, there is no need to run an extra command to set up the backend server;
# the server is initialized automatically during the init process

# launch lmms-eval srt_api model
CKPT_PATH=$1
TASK=$2
MODALITY=$3
TP_SIZE=$4
echo $TASK
TASK_SUFFIX="${TASK//,/_}"
echo $TASK_SUFFIX

python3 -m lmms_eval \
--model srt_api \
--model_args modality=$MODALITY,model_version=$CKPT_PATH,tp=$TP_SIZE,host=127.0.0.1,port=30000,timeout=600 \
--tasks $TASK \
--batch_size 1 \
--log_samples \
--log_samples_suffix $TASK_SUFFIX \
--output_path ./logs/
```
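If the snippet above is saved as a shell script (hypothetical name below), the positional arguments map to `CKPT_PATH`, `TASK`, `MODALITY`, and `TP_SIZE`:
```bash
# Hypothetical invocation; the checkpoint, task, modality, and tensor-parallel
# size are illustrative values, not the only supported ones.
bash run_srt_eval.sh lmms-lab/llava-onevision-qwen2-7b-ov mme image 1
```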

If you encounter errors, you may need to install the following dependencies for the above command to work:

```bash
pip install httpx==0.23.3
pip install protobuf==3.20
```


1 change: 1 addition & 0 deletions docs/current_tasks.md
@@ -105,6 +105,7 @@
- ScreenSpot REG / Instruction Generation (screenspot_reg)
- SeedBench (seedbench)
- SeedBench 2 (seedbench_2)
- SeedBench 2 Plus (seedbench_2_plus)
- ST-VQA (stvqa)
- TextCaps (textcaps)
- TextCaps Validation (textcaps_val)
4 changes: 2 additions & 2 deletions docs/task_guide.md
@@ -40,7 +40,7 @@ metric_list:
- metric: mme_cognition_score
aggregation: !function utils.mme_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
lmms_eval_specific_kwargs:
default:
pre_prompt: ""
post_prompt: "\nAnswer the question using a single word or phrase."
@@ -52,7 +52,7 @@ metadata:
```
You can pay special attention to the `process_results` and `metric_list` fields, which are used to define how the model output is post-processed and scored.
Also, the `model_specific_prompt_kwargs` field is used to define model-specific prompt configurations. The default is set to follow Llava.
Also, the `lmms_eval_specific_kwargs` field is used to define model-specific prompt configurations. The default is set to follow Llava.

PPL-based tasks:
- Seedbench (`lmms_eval/tasks/seedbench/seedbench_ppl.yaml`)
21 changes: 8 additions & 13 deletions lmms_eval/api/task.py
@@ -861,6 +861,7 @@ def _download_from_youtube(path):

if "video" in dataset_kwargs and dataset_kwargs["video"]:
hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
hf_home = os.path.expanduser(hf_home)
cache_dir = dataset_kwargs["cache_dir"]
cache_dir = os.path.join(hf_home, cache_dir)
accelerator = Accelerator()
@@ -955,13 +956,13 @@ def concat_tar_parts(tar_parts, output_tar):
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)
self.dataset_no_image = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
download_config=download_config,
**dataset_kwargs if dataset_kwargs is not None else {},
)
if self.config.process_docs is not None:
for split in self.dataset:
if split in [self.config.training_split, self.config.validation_split, self.config.test_split, self.config.fewshot_split]:
self.dataset[split] = self.config.process_docs(self.dataset[split])

# copy dataset, remove image features
self.dataset_no_image = self.dataset.copy()
for doc_name in self.dataset_no_image:
remove_cols = []
features = self.dataset_no_image[doc_name].features
@@ -994,20 +995,14 @@ def has_test_docs(self) -> bool:

def training_docs(self) -> datasets.Dataset:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.training_split])
return self.dataset[self.config.training_split]

def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.validation_split])
return self.dataset[self.config.validation_split]

def test_docs(self) -> datasets.Dataset:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self.config.test_split]

def fewshot_docs(self):
7 changes: 6 additions & 1 deletion lmms_eval/evaluator.py
@@ -604,8 +604,13 @@ def evaluate(
}
if log_samples:
results_dict["samples"] = dict(samples)
else:
results_dict = None

return results_dict
with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", "w") as f:
f.write(f"rank {int(os.environ.get('RANK', 0))} eval done")
while len([file for file in os.listdir(cli_args.output_path) if file.endswith("metric_eval_done.txt")]) < lm._world_size:
time.sleep(1)

else:
return None
Expand Down
5 changes: 4 additions & 1 deletion lmms_eval/models/internvl.py
@@ -455,7 +455,10 @@ def _collate(x):
split = split[0]
batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
flattened_visuals = self.flatten(batched_visuals)
pixel_values = self.load_image(flattened_visuals, self.image_size).cuda().to(torch.bfloat16)
try:
pixel_values = self.load_image(flattened_visuals, self.image_size).cuda().to(torch.bfloat16)
except IndexError:
pixel_values = None
gen_kwargs = all_gen_kwargs[0]

if "max_new_tokens" not in gen_kwargs:
10 changes: 7 additions & 3 deletions lmms_eval/models/internvl2.py
@@ -148,6 +148,7 @@ def __init__(

accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
@@ -254,13 +255,16 @@ def generate_until(self, requests) -> List[str]:
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
if self.modality == "image":
visuals = [load_image(visual).to(torch.bfloat16).cuda() for visual in visuals]
pixel_values = torch.cat(visuals, dim=0)
num_patches_list = [visual.size(0) for visual in visuals]
if visuals:
visuals = [load_image(visual).to(torch.bfloat16).cuda() for visual in visuals]
pixel_values = torch.cat(visuals, dim=0)
num_patches_list = [visual.size(0) for visual in visuals]
image_tokens = ["<image>"] * len(visuals)
image_tokens = " ".join(image_tokens)
contexts = image_tokens + "\n" + contexts
else:
pixel_values = None
num_patches_list = None
response, history = self.model.chat(self.tokenizer, pixel_values, contexts, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True)
elif self.modality == "video":
assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos."
1 change: 1 addition & 0 deletions lmms_eval/models/llava.py
@@ -73,6 +73,7 @@ def __init__(

accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
11 changes: 4 additions & 7 deletions lmms_eval/models/llava_hf.py
@@ -56,6 +56,7 @@ def __init__(
device_map: str = "",
chat_template: Optional[str] = None,
use_cache: bool = True,
specified_eot_token_id: Optional[int] = None,
**kwargs,
) -> None:
super().__init__()
@@ -89,6 +90,7 @@ def __init__(
self.batch_size_per_gpu = int(batch_size)
self.chat_template = chat_template
self.use_cache = use_cache
self.specified_eot_token_id = specified_eot_token_id
if accelerator.num_processes > 1 and device_map == "":
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
@@ -320,18 +322,13 @@ def _collate(x):
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=self.use_cache,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.specified_eot_token_id,
)
cont = cont[:, inputs["input_ids"].shape[-1] :]
except Exception as e:
eval_logger.error(f"Error {e} in generating")
cont = ""
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
if "1.5" in self.pretrained:
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
elif "mistral" in self.pretrained:
text_outputs = text_outputs.split("[/INST]")[-1].strip()
else:
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()

if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")

5 changes: 3 additions & 2 deletions lmms_eval/models/llava_vid.py
@@ -109,14 +109,15 @@ def __init__(
self.max_frames_num = int(max_frames_num)
self.mm_resampler_location = mm_pooling_position
self.delay_load = delay_load

if self.overwrite == True:
overwrite_config = {}
overwrite_config["mm_resampler_type"] = self.mm_resampler_type
overwrite_config["mm_spatial_pool_stride"] = self.mm_spatial_pool_stride
overwrite_config["mm_spatial_pool_out_channels"] = self.mm_spatial_pool_out_channels
overwrite_config["mm_spatial_pool_mode"] = self.mm_spatial_pool_mode
overwrite_config["mm_pooling_position"] = self.mm_resampler_location
overwrite_config["mm_newline_position"] = mm_newline_position
overwrite_config["mm_newline_position"] = self.mm_newline_position
overwrite_config["add_faster_video"] = False
overwrite_config["delay_load"] = self.delay_load
# overwrite_config["attn_implementation"] = attn_implementation
@@ -145,7 +146,7 @@ def __init__(
self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False)
cfg_pretrained = LlavaConfig.from_pretrained(pretrained)
if overwrite_config is not None:
print(f"Overwriting config with {overwrite_config}")
eval_logger.log(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(cfg_pretrained, k, v)
kwargs["torch_dtype"] = torch.float16
1 change: 1 addition & 0 deletions lmms_eval/models/longva.py
@@ -87,6 +87,7 @@ def __init__(

accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
