Skip to content

Commit

Permalink
few_shot dataset loading
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 27, 2024
1 parent 9dfb53a commit 2571960
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
19 changes: 15 additions & 4 deletions lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TaskConfig(dict):
training_split: str = None
validation_split: str = None
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# fewshot_split: str = None # TODO: assert that this is not None if num_fewshot > 0 (?); assert whether this is the same split as the one being evaluated (?)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable = None
Expand Down Expand Up @@ -550,8 +550,12 @@ def __init__(self, model_name) -> None: # TODO no super() call here
self._filters.append(filter_pipeline)
else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self.config.task == "flickr30k_test":
pass # TODO: for test, will delete later
if self.config.fewshot_config is not None:
self.sampler = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")(list(self.fewshot_docs()), self, rnd=random.Random(1234))
random_seed = self.config.fewshot_config.get("random_seed", 1234)
sampler_function = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")
self.sampler = sampler_function(list(self.fewshot_docs()), self, rnd=random.Random(random_seed))

if self.has_test_docs():
self.task_docs = self.test_docs()
Expand Down Expand Up @@ -742,8 +746,15 @@ def test_docs(self) -> datasets.Dataset:
return self.dataset[self.config.test_split]

def fewshot_docs(self):
if self.config.fewshot_split is not None:
return self.dataset[self.config.fewshot_split]
if "fewshot_dataset" in self.config.fewshot_config:
fewshot_dataset_config = self.config.fewshot_config["fewshot_dataset"]
return datasets.load_dataset(
path=fewshot_dataset_config["dataset_path"],
name=fewshot_dataset_config.get("dataset_name", None),
split=fewshot_dataset_config["fewshot_split"],
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
**fewshot_dataset_config["dataset_kwargs"] if "dataset_kwargs" in fewshot_dataset_config else {},
)
else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning(f"Task '{self.config.task}': " "num_fewshot > 0 but fewshot_split is None. " "using preconfigured rule.")
Expand Down
9 changes: 8 additions & 1 deletion lmms_eval/tasks/flickr30k/flickr30k_test.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
dataset_path: lmms-lab/flickr30k
dataset_kwargs:
token: True
task : "flickr30k_test"
task: "flickr30k_test"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.flickr_doc_to_visual
Expand Down Expand Up @@ -40,5 +40,12 @@ metric_list:
#- metric: flickr_SPICE
# aggregation : !function utils.flickr_spice
# higher_is_better : true
# Few-shot example source for this task.
# Key names must match what lmms_eval/api/task.py reads:
#   - fewshot_config["sampler"]       (optional, falls back to "default")
#   - fewshot_config["random_seed"]   (optional, falls back to 1234)
#   - fewshot_config["fewshot_dataset"] with:
#       dataset_path   (required)
#       dataset_name   (optional)
#       fewshot_split  (required)
#       dataset_kwargs (optional, forwarded to datasets.load_dataset)
fewshot_config:
  sampler: default
  fewshot_dataset:
    dataset_path: lmms-lab/flickr30k
    # dataset_name:
    fewshot_split: train
  # random_seed: 1234
metadata:
- version: 0.0

0 comments on commit 2571960

Please sign in to comment.