diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py
index 7c89535a..e43b588f 100755
--- a/lmms_eval/__main__.py
+++ b/lmms_eval/__main__.py
@@ -264,6 +264,7 @@ def parse_eval_args() -> argparse.Namespace:
         action="store_true",
         help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
     )
+    parser.add_argument("--process_with_media", action="store_true", help="Whether to process your dataset with audio or image data. By default set to False. In case some benchmarks need to be processed with media, set this flag to True.")
     args = parser.parse_args()
     return args
 
diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py
index 937b0fcb..032acc93 100755
--- a/lmms_eval/api/task.py
+++ b/lmms_eval/api/task.py
@@ -29,7 +29,7 @@
 import datasets
 import numpy as np
 from accelerate import Accelerator
-from datasets import DownloadConfig, Image, Sequence
+from datasets import Audio, DownloadConfig, Image, Sequence
 from huggingface_hub import snapshot_download
 from loguru import logger as eval_logger
 from PIL import ImageFile
@@ -430,9 +430,10 @@ def build_all_requests(
         if cache_requests and (not cached_instances or rewrite_requests_cache) and limit is not None:
             limit = None
 
-        doc_id_docs = list(self.doc_iterator(rank=rank, limit=limit, world_size=world_size))
+        doc_id_docs = self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
+        doc_iterator_for_counting = itertools.islice(range(len(self.test_docs())), rank, limit, world_size) if self.has_test_docs() else itertools.islice(range(len(self.validation_docs())), rank, limit, world_size)
 
-        num_docs = len(doc_id_docs)
+        num_docs = sum(1 for _ in doc_iterator_for_counting)
 
         for doc_id, doc in tqdm(
             doc_id_docs,
@@ -1064,6 +1065,8 @@ def concat_tar_parts(tar_parts, output_tar):
                     remove_cols.append(feature)
                 elif isinstance(features[feature], Sequence) and isinstance(features[feature].feature, Image):
                     remove_cols.append(feature)
+                elif isinstance(features[feature], Audio):
+                    remove_cols.append(feature)
             for remove_col in remove_cols:
                 self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(remove_col)
 
@@ -1093,10 +1096,27 @@ def validation_docs(self) -> datasets.Dataset:
         if self.has_validation_docs():
             return self.dataset[self.config.validation_split]
 
+    def validation_docs_no_media(self) -> datasets.Dataset:
+        if self.has_validation_docs():
+            return self.dataset_no_image[self.config.validation_split]
+
     def test_docs(self) -> datasets.Dataset:
         if self.has_test_docs():
             return self.dataset[self.config.test_split]
 
+    def test_docs_no_media(self) -> datasets.Dataset:
+        if self.has_test_docs():
+            return self.dataset_no_image[self.config.test_split]
+
+    @property
+    def eval_docs_no_media(self) -> Union[datasets.Dataset, List[dict]]:
+        if self.has_test_docs():
+            return self.test_docs_no_media()
+        elif self.has_validation_docs():
+            return self.validation_docs_no_media()
+        else:
+            raise ValueError(f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!")
+
     def fewshot_docs(self):
         if self.config.fewshot_split is not None:
             return self.dataset[self.config.fewshot_split]
diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py
index c6460e9d..11e6c1c5 100755
--- a/lmms_eval/evaluator.py
+++ b/lmms_eval/evaluator.py
@@ -483,7 +483,10 @@ def evaluate(
             instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
         for filter_key in task.instances[0].filtered_resps.keys():
-            doc_iterator = task.doc_iterator(rank=RANK, limit=limit, world_size=WORLD_SIZE)
+            if not cli_args.process_with_media:
+                doc_iterator = create_iterator(enumerate(task.eval_docs_no_media), rank=RANK, limit=int(limit) if limit else None, world_size=WORLD_SIZE)
+            else:
+                doc_iterator = task.doc_iterator(rank=RANK, limit=limit, world_size=WORLD_SIZE)
             doc_iterator_for_counting = itertools.islice(range(len(task.test_docs())), RANK, limit, WORLD_SIZE) if task.has_test_docs() else itertools.islice(range(len(task.validation_docs())), RANK, limit, WORLD_SIZE)
             total_docs = sum(1 for _ in doc_iterator_for_counting)
             pbar = tqdm(total=total_docs, desc=f"Postprocessing", disable=(RANK != 0))
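Usage note (not part of the patch): with this change, the postprocessing loop in evaluator.py iterates over a media-stripped copy of the dataset (eval_docs_no_media) unless --process_with_media is passed. A minimal sketch of an invocation that keeps audio/image columns available during postprocessing, with model and task names left as placeholders:

    python -m lmms_eval --model <model> --tasks <task> --process_with_media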