From f6fc367689f508af9db2e4f697d3d20957fa2e11 Mon Sep 17 00:00:00 2001 From: Fanyi Pu Date: Sun, 31 Mar 2024 00:32:22 +0800 Subject: [PATCH] Add description to Context class and update ConfigurableTask to include the description in labeled examples --- lmms_eval/api/samplers.py | 10 ++++++++-- lmms_eval/api/task.py | 12 ++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py index f25a2161..43f95319 100644 --- a/lmms_eval/api/samplers.py +++ b/lmms_eval/api/samplers.py @@ -12,7 +12,7 @@ def get_images(self, doc_to_visual): class Context(object): - def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n"): + def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n", description = None): self.task = task self.config = task._config @@ -25,6 +25,12 @@ def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str self.few_shot_delimiter = few_shot_delimiter self.contexts = [] + + if description: + self.add_description(description) + + def add_description(self, description): + self.contexts = [description] + self.contexts def get_question(self, doc): text = self.doc_to_text(doc) @@ -109,7 +115,7 @@ def __init__(self, docs: FewShotDataset, task, fewshot_indices=None, rnd=None) - if fewshot_indices: # subset few-shot docs from self.docs.fewshot_indices = fewshot_indices - def get_context(self, doc, num_fewshot): + def get_context(self, doc, num_fewshot) -> Context: # draw an extra fewshot sample if using same split as evaluating on n_samples = num_fewshot + 1 if self.docs.same_as_eval else num_fewshot diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index abd36470..b365673a 100644 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -788,20 +788,20 @@ def fewshot_context(self, doc_id, num_fewshot, split): doc = self.dataset_no_image[split][doc_id] if num_fewshot == 0: # always prepend the (possibly empty) task description - labeled_examples = self.config.description + labeled_examples = [self.config.description] else: - labeled_examples = self.config.description + self.sampler.get_context(doc, num_fewshot) + labeled_examples = [self.config.description] + self.sampler.get_context(doc, num_fewshot).contexts example = self.doc_to_text(doc) if type(example) == str: - return labeled_examples + example + return labeled_examples + [example] elif type(example) == list: - return [labeled_examples + ex for ex in example] + return labeled_examples + [ex for ex in example] elif type(example) == int: if self.config.doc_to_choice is not None: choices = self.doc_to_choice(doc) - return labeled_examples + choices[example] + return labeled_examples + [choices[example]] else: - return labeled_examples + str(example) + return labeled_examples + [str(example)] def apply_filters(self): if hasattr(self, "_filters"):