diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py index e0688851..c9084e9b 100644 --- a/lmms_eval/api/samplers.py +++ b/lmms_eval/api/samplers.py @@ -2,7 +2,7 @@ class FewShotDataset(object): - def __init__(self, dataset=None, *, dataset_path: str = None, dataset_name: str = None, split: str = None, dataset_kwargs: dict = None): + def __init__(self, dataset=None, *, dataset_path: str = None, dataset_name: str = None, split: str = None, dataset_kwargs: dict = None, same_as_eval: bool = False): if dataset is not None and (dataset_path is not None or dataset_name is not None or split is not None or dataset_kwargs is not None): raise ValueError("Cannot provide both `dataset` and other dataset arguments!") self.dataset_path = dataset_path @@ -10,15 +10,22 @@ def __init__(self, dataset=None, *, dataset_path: str = None, dataset_name: str self.split = split self.dataset = dataset self.dataset_kwargs = dataset_kwargs if dataset_kwargs is not None else {} + self.same_as_eval = same_as_eval + self.fewshot_indices = None - def get_dataset(self): + def get_dataset(self) -> datasets.Dataset: if self.dataset is None: self.dataset = datasets.load_dataset(path=self.dataset_path, name=self.dataset_name, split=self.split, download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, **self.dataset_kwargs) + if self.fewshot_indices: + self.dataset = self.dataset.select(self.fewshot_indices) return self.dataset + def __getitem__(self, item): + return self.get_dataset()[item] + class ContextSampler: - def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None: + def __init__(self, docs: FewShotDataset, task, fewshot_indices=None, rnd=None) -> None: self.rnd = rnd assert self.rnd, "must pass rnd to FewShotSampler!" @@ -32,13 +39,13 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None: self.doc_to_target = self.task.doc_to_target self.doc_to_choice = self.task.doc_to_choice - self.docs = docs # HF dataset split, provided by task._fewshot_docs() + self.docs: FewShotDataset = docs # HF dataset split, provided by task._fewshot_docs() if fewshot_indices: # subset few-shot docs from - self.docs = self.docs.select(fewshot_indices) + self.docs.fewshot_indices = fewshot_indices def get_context(self, doc, num_fewshot): # draw an extra fewshot sample if using same split as evaluating on - n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot + n_samples = num_fewshot + 1 if self.docs.same_as_eval else num_fewshot # draw `n_samples` docs from fewshot_docs fewshotex = self.sample(n_samples) @@ -71,7 +78,7 @@ def sample(self, n): Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. """ - return self.rnd.sample(self.docs, n) + return self.rnd.sample(self.docs.get_dataset(), n) class FirstNSampler(ContextSampler): diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index 96139475..2fff7252 100644 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -312,7 +312,7 @@ def fewshot_docs(self): else: if self.config.num_fewshot is not None: eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.") - return FewShotDataset(self.test_docs()) + return FewShotDataset(self.test_docs(), same_as_eval=True) def _process_doc(self, doc): """ @@ -758,11 +758,15 @@ def fewshot_docs(self): if "fewshot_dataset" in self.config.fewshot_config: fewshot_dataset_config = self.config.fewshot_config["fewshot_dataset"] + if "dataset_path" not in fewshot_dataset_config: + fewshot_dataset_config["dataset_path"] = self.config.dataset_path + same_as_eval = self.config.dataset_path == fewshot_dataset_config["dataset_path"] and self.config.dataset_name == fewshot_dataset_config.get("dataset_name", None) and self.config.test_split == fewshot_dataset_config["split"] return FewShotDataset( dataset_path=fewshot_dataset_config["dataset_path"], dataset_name=fewshot_dataset_config.get("dataset_name", None), split=fewshot_dataset_config["split"], dataset_kwargs=fewshot_dataset_config.get("dataset_kwargs", {}), + same_as_eval=same_as_eval, ) else: if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): diff --git a/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml b/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml index 282b401b..fd485002 100644 --- a/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml +++ b/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml @@ -15,3 +15,10 @@ model_specific_prompt_kwargs: qwen_vl: pre_prompt: "" post_prompt: " Answer:" +fewshot_config: + fewshot_sampler: default + fewshot_dataset: + # dataset_path: lmms-lab/flickr30k + # dataset_name: + split: train + # random_seed: 1234