
Commit

[Feat][Task] Add multi-round evaluation in llava-onevision; Add MMSearch Benchmark (#277)

* Add MMSearch. Add generate_until_multi_round function in LLaVA-OneVision

* Fix linting error

* Update multi-gpu end2end inference support.
Update README.md about multi-round interaction.

* Fix linting error

* Fix linting error
CaraJ7 authored and KairuiHu committed Oct 24, 2024
1 parent 81297da commit 77dfbac
Showing 32 changed files with 2,245 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
100755 → 100644
@@ -17,6 +17,7 @@
---

## Announcement
- [2024-09] 🎉🎉 We welcome the new task [MMSearch](https://mmsearch.github.io/).
- [2024-09] 🎉🎉 We welcome the new task [MME-RealWorld](https://mme-realworld.github.io/) for inference acceleration
- [2024-09] ⚙️️⚙️️️️ We upgrade `lmms-eval` to `0.2.3` with more tasks and features. We support a compact set of language tasks evaluations (code credit to [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)), and we remove the registration logic at start (for all models and tasks) to reduce the overhead. Now `lmms-eval` only launches necessary tasks/models. Please check the [release notes](https://github.com/EvolvingLMMs-Lab/lmms-eval/releases/tag/v0.2.3) for more details.
- [2024-08] 🎉🎉 We welcome the new model [LLaVA-OneVision](https://huggingface.co/papers/2408.03326), [Mantis](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/162), new tasks [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench), [LongVideoBench](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/117), [MMStar](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/158). We provide new feature of SGlang Runtime API for llava-onevision model, please refer the [doc](https://github.com/EvolvingLMMs-Lab/lmms-eval/blob/main/docs/commands.md) for inference acceleration
43 changes: 41 additions & 2 deletions docs/task_guide.md
100755 → 100644
@@ -75,6 +75,44 @@ metadata:
  - version: 0.0
```

Multi-round-generation-based tasks:

- MMSearch (`lmms_eval/tasks/mmsearch/mmsearch_end2end.yaml`)

```yaml
dataset_path: CaraJ/MMSearch
dataset_name: end2end
dataset_kwargs:
  token: False
task: "mmsearch_end2end"
test_split: end2end
output_type: generate_until_multi_round # Note that we use the new output_type here for multi-round generation. It basically follows generate_until but incorporates multi-round inference
doc_to_visual: !function lmms_eval_utils.mmsearch_end2end_doc_to_visual
doc_to_text: !function lmms_eval_utils.mmsearch_end2end_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  until:
    - "ASSISTANT:"
  max_new_tokens: 512
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
process_results: !function lmms_eval_utils.mmsearch_end2end_process_results
metric_list:
  - metric: end2end_f1_score
    aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_f1_score
    higher_is_better: true
  - metric: requery_score
    aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_req_score
    higher_is_better: true
lmms_eval_specific_kwargs: # Note that here we cache the result of every sample as soon as it is inferred
  middle_resules_dir: /data1/zrr/jdz/mmsearch/mmsearch_middile_results
  result_cache_dir: /data1/zrr/jdz/mmsearch/mmsearch_result_cache_dir
```
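The `process_results` and `aggregation` hooks referenced above live in `lmms_eval/tasks/mmsearch/lmms_eval_utils.py`. As a rough, hypothetical sketch of the contract they follow (not the actual MMSearch scoring): `process_results(doc, results)` is called once per sample and returns a dict keyed by the metric names in `metric_list`, and each aggregation callable reduces the list of per-sample values to a single score. The function names and the naive token-overlap F1 below are illustrative assumptions.

```python
# Hypothetical sketch of the per-sample and aggregation hooks; not the MMSearch implementation.
def example_process_results(doc, results):
    prediction = results[0]  # model output for this sample (final round for multi-round tasks)
    pred_tokens = set(prediction.lower().split())
    gold_tokens = set(doc["answer"].lower().split())
    overlap = len(pred_tokens & gold_tokens)
    precision = overlap / max(len(pred_tokens), 1)
    recall = overlap / max(len(gold_tokens), 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)
    # One entry per metric declared in metric_list
    return {"end2end_f1_score": f1, "requery_score": f1}


def example_aggregate(per_sample_scores):
    # Receives the list of values returned above for one metric across all samples
    return sum(per_sample_scores) / max(len(per_sample_scores), 1)
```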


## Configurations

Tasks are configured via the `TaskConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
@@ -96,8 +134,9 @@ Dataset configuration options:
- **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to that expected by a prompt template.

Prompting / in-context formatting options:
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate input for the model
- **doc_to_visual** (`Union[Callable, str]`, *optional*) — Function to process a sample into the appropriate input images for the model.
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate input for the model.

For multi-round generation (e.g., MMSearch), the function accepts additional parameters: the round index, information from previous rounds, and the previous model output. It should return the input images for the next round, the input text for the next round, a boolean indicating whether round inference should terminate, the model outputs from all rounds, and extra information from previous rounds (see the sketch at the end of this section).
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into the list of answer choices.
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.

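As a rough sketch of the multi-round `doc_to_text` contract described above (parameter names, document fields, and the two-round policy are illustrative assumptions; the authoritative version is `mmsearch_end2end_doc_to_text` in `lmms_eval/tasks/mmsearch/lmms_eval_utils.py`):

```python
# Hypothetical multi-round doc_to_text; names and fields are assumptions, not the MMSearch code.
def example_multi_round_doc_to_text(doc, lmms_eval_specific_kwargs=None, round_idx=0, previous_round_info=None, previous_output=None):
    if round_idx == 0:
        # First round: build the initial prompt from the document itself.
        return [doc["query_image"]], doc["query"], False, [], {}
    # Later rounds: fold the previous model output into the next prompt.
    prior_outputs = (previous_round_info or {}).get("outputs", []) + [previous_output]
    next_text = f"Your previous answer was: {previous_output}\nRefine it using the new evidence."
    terminate = round_idx >= 2  # stop after a fixed number of rounds in this toy example
    return [doc["query_image"]], next_text, terminate, prior_outputs, {"outputs": prior_outputs}
```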
2 changes: 1 addition & 1 deletion lmms_eval/api/instance.py
@@ -4,7 +4,7 @@

@dataclass
class Instance:
    request_type: Literal["loglikelihood", "generate_until"]
    request_type: Literal["loglikelihood", "generate_until", "generate_until_multi_round"]
    arguments: tuple
    idx: int
    metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None)) # TODO: better typehints here
8 changes: 4 additions & 4 deletions lmms_eval/api/metrics.py
@@ -338,7 +338,7 @@ def mean_stderr(arr):
@register_metric(
    metric="bypass",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice", "generate_until"],
    output_type=["loglikelihood", "multiple_choice", "generate_until", "generate_until_multi_round"],
    aggregation="bypass",
)
def bypass(items):
@@ -368,7 +368,7 @@ def f1_fn(items): # This is a passthrough function
@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    output_type=["generate_until", "generate_until_multi_round"],
    aggregation="bleu",
)
def bleu_fn(items): # This is a passthrough function
@@ -378,7 +378,7 @@ def bleu_fn(items): # This is a passthrough function
@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    output_type=["generate_until", "generate_until_multi_round"],
    aggregation="chrf",
)
def chrf_fn(items): # This is a passthrough function
@@ -388,7 +388,7 @@ def chrf_fn(items): # This is a passthrough function
@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    output_type=["generate_until", "generate_until_multi_round"],
    aggregation="ter",
)
def ter_fn(items): # This is a passthrough function
21 changes: 20 additions & 1 deletion lmms_eval/api/model.py
@@ -73,6 +73,25 @@ def generate_until(self, requests) -> List[str]:
        """
        pass

    @abc.abstractmethod
    def generate_until_multi_round(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence
        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, until).
            context: str
                Context string
            generation_kwargs: dict
                Generation Kwargs
            'visual_list: list[dict]'
                Visual input to the model. Can be None.
        :return: list[str]
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        pass

@classmethod
def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Optional[dict] = None) -> T:
"""
@@ -160,7 +179,7 @@ def fn(requests):
eval_logger.info(f"Loading '{attr}' responses from cache '{self.cache_db}' where possible...")
for req in tqdm(requests):
hsh = hash_args(attr, req.args)
if attr == "generate_until" and req.args[1].get("do_sample", False):
if attr in ["generate_until", "generate_until_multi_round"] and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned:
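For context on the model side: a backend advertising `generate_until_multi_round` is expected to loop over rounds, feeding each round's output back through the task's `doc_to_text` callable until it signals termination. The sketch below is an illustrative mixin, not the actual llava_onevision implementation; it assumes the argument layout packed by `construct_requests` for this output type (context, generation kwargs, `doc_to_visual`, `doc_to_text`, doc id, task, split) and a hypothetical single-round helper `_generate_one_round`.

```python
from typing import List


class ExampleMultiRoundMixin:
    """Illustrative sketch only; the real model class may handle batching and arguments differently."""

    MAX_ROUNDS = 5  # safety cap; an assumption, not a framework constant

    def generate_until_multi_round(self, requests) -> List[str]:
        responses = []
        for request in requests:
            contexts, gen_kwargs, doc_to_visual, doc_to_text, doc_id, task, split = request.args
            doc = self.task_dict[task][split][doc_id]
            visuals, text = doc_to_visual(doc), contexts
            previous_info, output = {}, ""
            for round_idx in range(self.MAX_ROUNDS):
                # _generate_one_round is a hypothetical helper wrapping a single forward pass
                output = self._generate_one_round(visuals, text, gen_kwargs)
                visuals, text, terminate, all_outputs, previous_info = doc_to_text(
                    doc, round_idx=round_idx + 1, previous_round_info=previous_info, previous_output=output
                )
                if terminate:
                    break
            responses.append(output)
        return responses
```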
1 change: 1 addition & 0 deletions lmms_eval/api/registry.py
@@ -76,6 +76,7 @@ def decorate(fn):
    ],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
    "generate_until_multi_round": ["exact_match"],
}


14 changes: 9 additions & 5 deletions lmms_eval/api/task.py
@@ -11,6 +11,7 @@
import subprocess
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from functools import partial
from glob import glob
from typing import (
    Any,
@@ -59,6 +60,7 @@
"loglikelihood",
"multiple_choice",
"generate_until",
"generate_until_multi_round",
]


@@ -130,17 +132,17 @@ def __post_init__(self) -> None:
raise ValueError("Got both a `group` and `tag` entry within a TaskConfig. Please use one or the other--`group` values will be deprecated in v0.4.4.")

if self.generation_kwargs is not None:
if self.output_type != "generate_until":
if "generate_until" not in self.output_type:
eval_logger.warning(f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!")
assert self.output_type != "generate_until"
assert "generate_until" not in self.output_type

if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(self.generation_kwargs["temperature"])

if "until" not in self.generation_kwargs:
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
if "generate_until" in self.output_type:
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": None if self.fewshot_delimiter is None else [self.fewshot_delimiter],
@@ -1380,6 +1382,8 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst

        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, copy.deepcopy(self.config.generation_kwargs), self.doc_to_visual, doc_id, self.config.task, split)
        elif self.OUTPUT_TYPE == "generate_until_multi_round":
            arguments = (ctx, copy.deepcopy(self.config.generation_kwargs), self.doc_to_visual, partial(self.config.doc_to_text, lmms_eval_specific_kwargs=self.lmms_eval_specific_kwargs), doc_id, self.config.task, split)
        return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)

# TODO: we add a full_docs interface here for some evaluations that needs to access the full datasets during process_results function. we may have better ways to handle this.
@@ -1466,7 +1470,7 @@ def process_results(self, doc, results, full_docs=None):
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info

elif self.OUTPUT_TYPE == "generate_until":
elif "generate_until" in self.OUTPUT_TYPE:
gold = self.doc_to_target(doc)
result = results[0]
if self.config.doc_to_choice is not None:
@@ -1524,7 +1528,7 @@ def process_results(self, doc, results, full_docs=None):
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood','generate_until' or 'multiple_choice'",
"'loglikelihood','generate_until', 'generate_until_multi_round', or 'multiple_choice'",
)

return result_dict
2 changes: 1 addition & 1 deletion lmms_eval/evaluator.py
@@ -197,7 +197,7 @@ def _adjust_config(task_dict):
if task_obj is None:
continue
lm.task_dict[task_name] = task_obj.dataset
if task_obj.get_config("output_type") == "generate_until":
if "generate_until" in task_obj.get_config("output_type"):
if gen_kwargs is not None:
task_obj.set_config(key="generation_kwargs", value=gen_kwargs, update=True)

2 changes: 1 addition & 1 deletion lmms_eval/logging_utils.py
@@ -278,7 +278,7 @@ def _generate_dataset(self, data: List[Dict[str, Any]], config: Dict[str, Any])
choices = ["\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])]) for x in data]
resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
filtered_resps = [np.argmax([n[0] for n in x["filtered_resps"]]) for x in data]
elif config["output_type"] == "generate_until":
elif "generate_until" in config["output_type"]:
instance = [x["arguments"][0][0] for x in data]
resps = [x["resps"][0][0] for x in data]
filtered_resps = [x["filtered_resps"][0] for x in data]