diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py
index 3262937c..aa6edf9a 100644
--- a/lmms_eval/api/task.py
+++ b/lmms_eval/api/task.py
@@ -64,7 +64,7 @@ class TaskConfig(dict):
     training_split: str = None
     validation_split: str = None
     test_split: str = None
-    fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
+    # fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
     # formatting / prompting options.
     # see docs/advanced_task_guide.md for more info
     process_docs: Callable = None
@@ -550,8 +550,11 @@ def __init__(self, model_name) -> None:  # TODO no super() call here
                 self._filters.append(filter_pipeline)
         else:
             self._filters = [build_filter_ensemble("none", [["take_first", None]])]
         if self.config.fewshot_config is not None:
-            self.sampler = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")(list(self.fewshot_docs()), self, rnd=random.Random(1234))
+            # The seed is read from fewshot_config["random_seed"]; the 1234 default keeps the previously hard-coded value.
+            random_seed = self.config.fewshot_config.get("random_seed", 1234)
+            sampler_function = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default"))
+            self.sampler = sampler_function(list(self.fewshot_docs()), self, rnd=random.Random(random_seed))
 
         if self.has_test_docs():
             self.task_docs = self.test_docs()
@@ -742,8 +745,24 @@ def test_docs(self) -> datasets.Dataset:
             return self.dataset[self.config.test_split]
 
     def fewshot_docs(self):
-        if self.config.fewshot_split is not None:
-            return self.dataset[self.config.fewshot_split]
+        """Return the documents that few-shot examples are sampled from.
+
+        If ``fewshot_config`` has a ``fewshot_dataset`` entry, load that dataset:
+        ``dataset_path`` and ``fewshot_split`` are required, ``dataset_name`` and
+        ``dataset_kwargs`` are optional. Otherwise fall back to the preconfigured
+        rule, warning first when ``num_fewshot`` > 0.
+        """
+        fewshot_config = self.config.fewshot_config
+        # fewshot_config is optional; a bare membership test on None would raise TypeError.
+        if fewshot_config is not None and "fewshot_dataset" in fewshot_config:
+            fewshot_dataset_config = fewshot_config["fewshot_dataset"]
+            return datasets.load_dataset(
+                path=fewshot_dataset_config["dataset_path"],
+                name=fewshot_dataset_config.get("dataset_name", None),
+                split=fewshot_dataset_config["fewshot_split"],
+                download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
+                **(fewshot_dataset_config.get("dataset_kwargs") or {}),
+            )
         else:
             if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
                 eval_logger.warning(f"Task '{self.config.task}': " "num_fewshot > 0 but fewshot_split is None. " "using preconfigured rule.")
diff --git a/lmms_eval/tasks/flickr30k/flickr30k_test.yaml b/lmms_eval/tasks/flickr30k/flickr30k_test.yaml
index 737d9ff4..776eae39 100644
--- a/lmms_eval/tasks/flickr30k/flickr30k_test.yaml
+++ b/lmms_eval/tasks/flickr30k/flickr30k_test.yaml
@@ -1,7 +1,7 @@
 dataset_path: lmms-lab/flickr30k
 dataset_kwargs:
   token: True
-task : "flickr30k_test"
+task: "flickr30k_test"
 test_split: test
 output_type: generate_until
 doc_to_visual: !function utils.flickr_doc_to_visual
@@ -40,5 +40,12 @@ metric_list:
   #- metric: flickr_SPICE
   #  aggregation : !function utils.flickr_spice
   #  higher_is_better : true
+fewshot_config:
+  sampler: default
+  fewshot_dataset:
+    dataset_path: lmms-lab/flickr30k
+    # dataset_name:
+    fewshot_split: train
+  # random_seed: 1234
 metadata:
-  version: 0.0
\ No newline at end of file