Skip to content

Commit

Permalink
Add same_as_eval parameter to FewShotDataset constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 28, 2024
1 parent 89e89e0 commit 96129e0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
21 changes: 14 additions & 7 deletions lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,30 @@


class FewShotDataset(object):
    """Lazily loaded source of few-shot example documents.

    The instance wraps either an already materialised dataset object or the
    arguments needed to load one with ``datasets.load_dataset`` on first use.
    An optional list of row indices (``fewshot_indices``) restricts which rows
    are exposed through ``get_dataset`` and ``__getitem__``.
    """

    def __init__(self, dataset=None, *, dataset_path: str = None, dataset_name: str = None, split: str = None, dataset_kwargs: dict = None, same_as_eval: bool = False):
        """Store either a ready dataset or the arguments needed to load one.

        Args:
            dataset: An already loaded dataset. Mutually exclusive with the
                keyword-only loading arguments below.
            dataset_path: ``path`` argument forwarded to ``datasets.load_dataset``.
            dataset_name: ``name`` argument forwarded to ``datasets.load_dataset``.
            split: ``split`` argument forwarded to ``datasets.load_dataset``.
            dataset_kwargs: Extra keyword arguments forwarded to
                ``datasets.load_dataset``; ``None`` is treated as ``{}``.
            same_as_eval: Flag stored as-is for callers; it records that the
                few-shot documents are the same data that is being evaluated.

        Raises:
            ValueError: If ``dataset`` is given together with any of the
                loading arguments.
        """
        if dataset is not None and (dataset_path is not None or dataset_name is not None or split is not None or dataset_kwargs is not None):
            raise ValueError("Cannot provide both `dataset` and other dataset arguments!")
        self.dataset_path = dataset_path
        self.dataset_name = dataset_name
        self.split = split
        self.dataset = dataset
        self.dataset_kwargs = dataset_kwargs if dataset_kwargs is not None else {}
        self.same_as_eval = same_as_eval
        # Public knob: callers may assign a list of row indices at any time.
        self.fewshot_indices = None
        # Unfiltered dataset; subsets are always derived from this one so that
        # indices are never applied on top of an earlier subset.
        self._full_dataset = None
        # Index list that ``self.dataset`` currently reflects (None = unfiltered).
        self._applied_indices = None

    def get_dataset(self) -> "datasets.Dataset":
        """Return the few-shot dataset, loading and subsetting it on demand.

        The dataset is loaded at most once. If ``fewshot_indices`` is set, the
        matching rows are selected from the unfiltered dataset exactly once per
        distinct index list, regardless of whether the dataset was passed to
        the constructor or loaded lazily, and regardless of how many times this
        method is called.
        """
        if self._full_dataset is None:
            if self.dataset is None:
                self.dataset = datasets.load_dataset(
                    path=self.dataset_path,
                    name=self.dataset_name,
                    split=self.split,
                    download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
                    **self.dataset_kwargs,
                )
            self._full_dataset = self.dataset

        # Normalise "no subset" (None or empty) to None and snapshot the list
        # so later in-place mutation by the caller is detected as a change.
        wanted = list(self.fewshot_indices) if self.fewshot_indices else None
        if wanted != self._applied_indices:
            if wanted is None:
                self.dataset = self._full_dataset
            else:
                self.dataset = self._full_dataset.select(wanted)
            self._applied_indices = wanted
        return self.dataset

    def __getitem__(self, item):
        """Index into the (possibly subsetted) dataset."""
        return self.get_dataset()[item]


class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
def __init__(self, docs: FewShotDataset, task, fewshot_indices=None, rnd=None) -> None:
self.rnd = rnd
assert self.rnd, "must pass rnd to FewShotSampler!"

Expand All @@ -32,13 +39,13 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.doc_to_target = self.task.doc_to_target
self.doc_to_choice = self.task.doc_to_choice

self.docs = docs # HF dataset split, provided by task._fewshot_docs()
self.docs: FewShotDataset = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
self.docs = self.docs.select(fewshot_indices)
self.docs.fewshot_indices = fewshot_indices

def get_context(self, doc, num_fewshot):
# draw an extra fewshot sample if using same split as evaluating on
n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
n_samples = num_fewshot + 1 if self.docs.same_as_eval else num_fewshot

# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
Expand Down Expand Up @@ -71,7 +78,7 @@ def sample(self, n):
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""

return self.rnd.sample(self.docs, n)
return self.rnd.sample(self.docs.get_dataset(), n)


class FirstNSampler(ContextSampler):
Expand Down
6 changes: 5 additions & 1 deletion lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def fewshot_docs(self):
else:
if self.config.num_fewshot is not None:
eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.")
return FewShotDataset(self.test_docs())
return FewShotDataset(self.test_docs(), same_as_eval=True)

def _process_doc(self, doc):
"""
Expand Down Expand Up @@ -758,11 +758,15 @@ def fewshot_docs(self):

if "fewshot_dataset" in self.config.fewshot_config:
fewshot_dataset_config = self.config.fewshot_config["fewshot_dataset"]
if "dataset_path" not in fewshot_dataset_config:
fewshot_dataset_config["dataset_path"] = self.config.dataset_path
same_as_eval = self.config.dataset_path == fewshot_dataset_config["dataset_path"] and self.config.dataset_name == fewshot_dataset_config.get("dataset_name", None) and self.config.test_split == fewshot_dataset_config["split"]
return FewShotDataset(
dataset_path=fewshot_dataset_config["dataset_path"],
dataset_name=fewshot_dataset_config.get("dataset_name", None),
split=fewshot_dataset_config["split"],
dataset_kwargs=fewshot_dataset_config.get("dataset_kwargs", {}),
same_as_eval=same_as_eval,
)
else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
Expand Down
7 changes: 7 additions & 0 deletions lmms_eval/tasks/textvqa/_default_template_textvqa_yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ model_specific_prompt_kwargs:
qwen_vl:
pre_prompt: ""
post_prompt: " Answer:"
fewshot_config:
fewshot_sampler: default
fewshot_dataset:
# dataset_path: lmms-lab/flickr30k
# dataset_name:
split: train
# random_seed: 1234

0 comments on commit 96129e0

Please sign in to comment.