diff --git a/community_tasks/arabic_evals.py b/community_tasks/arabic_evals.py
index 323120cd..f575b5f0 100644
--- a/community_tasks/arabic_evals.py
+++ b/community_tasks/arabic_evals.py
@@ -72,7 +72,6 @@ def mmlu_arabic(line, task_name: str = None):
         choices=LETTER_INDICES_AR[:4],
         gold_index=gold_ix,
         instruction=instruction,
-        target_for_fewshot_sorting=LETTER_INDICES_AR[gold_ix],
     )


@@ -181,7 +180,6 @@ def arabic_exams(line, task_name: str = None):
         choices=LETTER_INDICES_AR[:4],
         gold_index=answer_index,
         instruction=instruction,
-        target_for_fewshot_sorting=choices[answer_index],
     )


@@ -231,7 +229,6 @@ def alghafa_prompt(line, task_name: str = None):
         choices=choices,
         gold_index=answer_index,
         instruction=instruction,
-        target_for_fewshot_sorting=choices[answer_index],
     )


@@ -371,7 +368,6 @@ def __init__(
 def boolq_prompt_arabic(line, task_name: str = None):
     question = line["question"]
     passage = line["passage"]
-    answer = "نعم" if line["answer"] else "لا"
     instruction = "بناء على المقطع التالي، أجب عن السؤال ب نعم أو لا"
     query = f"""{instruction}
 المقطع :
@@ -387,7 +383,6 @@ def boolq_prompt_arabic(line, task_name: str = None):
         choices=["نعم", "لا"],
         gold_index=0 if line["answer"] else 1,
         instruction=instruction,
-        target_for_fewshot_sorting=answer,
     )


@@ -423,7 +418,6 @@ def copa_prompt_arabic(line, task_name: str = None):
         choices=choices,
         gold_index=answer,
         instruction="",
-        target_for_fewshot_sorting=choices[answer],
     )


@@ -468,7 +462,6 @@ def hellaswag_prompt_arabic(line, task_name: str = None):
         choices=endings,
         gold_index=answer_index,
         instruction=instruction,
-        target_for_fewshot_sorting=endings[answer_index],
     )


@@ -506,7 +499,6 @@ def toxigen_prompt_arabic(line, task_name: str = None):
         choices=["لا", "نعم"],
         gold_index=label,
         instruction=instruction,
-        target_for_fewshot_sorting="نعم" if label == 1 else "لا",
     )


@@ -558,7 +550,6 @@ def sciq_prompt_arabic(line, task_name: str = None):
         choices=choices,
         gold_index=answer_index,
         instruction=instruction,
-        target_for_fewshot_sorting=choices[answer_index],
     )


diff --git a/community_tasks/serbian_eval.py b/community_tasks/serbian_eval.py
index e7948510..3b49c4cb 100644
--- a/community_tasks/serbian_eval.py
+++ b/community_tasks/serbian_eval.py
@@ -200,8 +200,6 @@ def serbian_eval_prompt(line: dict, task_name: Optional[str] = None) -> Doc:
         - choices (list of str): The list of available answer choices.
         - gold_index (int): The index of the correct answer.
         - instruction (str): The instruction shown to the user in Serbian.
-        - target_for_fewshot_sorting (Union[str, list of str]): The correct answer, either as a
-          string (for regular tasks) or a list of strings (for MMLU tasks).
""" question = line["query"] @@ -226,16 +224,12 @@ def serbian_eval_prompt(line: dict, task_name: Optional[str] = None) -> Doc: query += "\n\nKrajnji odgovor:" - # Finalize target_for_fewshot_sorting as we handle mmlu task group as string - target_for_fewshot_sorting = [choices[gold_index]] if task_name and "mmlu" in task_name else choices[gold_index] - return Doc( task_name=task_name, query=query, choices=choices, gold_index=gold_index, instruction=instruction, - target_for_fewshot_sorting=target_for_fewshot_sorting, ) diff --git a/examples/model_configs/endpoint_model.yaml b/examples/model_configs/endpoint_model.yaml index 2834cdd2..4bf2f060 100644 --- a/examples/model_configs/endpoint_model.yaml +++ b/examples/model_configs/endpoint_model.yaml @@ -10,8 +10,8 @@ model: accelerator: "gpu" region: "eu-west-1" vendor: "aws" - instance_size: "medium" - instance_type: "g5.2xlarge" + instance_size: "x1" + instance_type: "nvidia-a10g" framework: "pytorch" endpoint_type: "protected" namespace: null # The namespace under which to launch the endopint. Defaults to the current user's namespace diff --git a/examples/nanotron/custom_evaluation_tasks.py b/examples/nanotron/custom_evaluation_tasks.py index 6d4edd62..9ae06671 100644 --- a/examples/nanotron/custom_evaluation_tasks.py +++ b/examples/nanotron/custom_evaluation_tasks.py @@ -333,7 +333,6 @@ def mmlu_harness(line, task_name: str = None): task_name=task_name, query=prompt, choices=[" A", " B", " C", " D"], - target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], gold_index=gold_ix, instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", ) diff --git a/examples/nanotron/custom_task.py b/examples/nanotron/custom_task.py index 49332321..05cea969 100644 --- a/examples/nanotron/custom_task.py +++ b/examples/nanotron/custom_task.py @@ -36,7 +36,7 @@ def mmlu_signs(line, topic): return { "query": prompt, "choices": [" +", " *", " =", " #"] if is_few_shots else ["+", "*", "=", "#"], - "target_for_fewshot_sorting": [" +", " *", " =", " #"][gold_ix], + "fewshot_sorting_class": [" +", " *", " =", " #"][gold_ix], "gold_index": gold_ix, "instruction": f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", } @@ -58,7 +58,7 @@ def mmlu_numbers(line, topic): return { "query": prompt, "choices": [" 1", " 2", " 3", " 4"] if is_few_shots else ["1", "2", "3", "4"], - "target_for_fewshot_sorting": [" 1", " 2", " 3", " 4"][gold_ix], + "fewshot_sorting_class": [" 1", " 2", " 3", " 4"][gold_ix], "gold_index": gold_ix, "instruction": f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", } diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 88de9ade..0305b073 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -20,7 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-from typing import Iterator
+import math
+from typing import Iterator, Tuple

 import torch
 from torch.utils.data import Dataset
@@ -80,7 +81,7 @@ def init_split_limits(self, num_dataset_splits):
             )
             num_dataset_splits = 1

-        split_size = self.total_size // num_dataset_splits + 1
+        split_size = math.ceil(self.total_size / num_dataset_splits)
         splits_indices = [
             (ix * split_size, min((ix + 1) * split_size, self.total_size)) for ix in range(num_dataset_splits)
         ]
@@ -110,7 +111,7 @@ def get_original_order(self, new_arr: list) -> list:

         return original_order

-    def get_split_start_end(self, split_id: int) -> tuple[int, int]:
+    def get_split_start_end(self, split_id: int) -> Tuple[int, int]:
         """
         Get the start and end indices of a dataset split.

@@ -123,7 +124,7 @@ def get_split_start_end(self, split_id: int) -> tuple[int, int]:
         self.split_start, self.split_end = self.splits[split_id]
         return self.split_start, self.split_end

-    def splits_start_end_iterator(self) -> tuple[int, int]:
+    def splits_start_end_iterator(self) -> Iterator[Tuple[int, int]]:
         """
         Iterator that yields the start and end indices of each dataset split.
         Also updates the starting batch size for each split (trying to double
@@ -132,7 +133,10 @@ def splits_start_end_iterator(self) -> tuple[int, int]:
         Yields:
             tuple: A tuple containing the start and end indices of a split.
         """
-        for split_id in range(self.num_dataset_splits):
+        split_range = self.num_dataset_splits
+        if self.total_size == 0:
+            split_range = 0
+        for split_id in range(split_range):
             yield self.get_split_start_end(split_id)

     def __getitem__(self, index) -> Request:
@@ -247,7 +251,8 @@ def init_split_limits(self, num_dataset_splits):
                 "You cannot select the number of dataset splits for a generative evaluation at the moment. Automatically inferring."
             )

-        all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
+        if len(self.sorted_data) > 0:
+            all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
         splits_indices = [[0, None]]
         for ix, req in enumerate(self.sorted_data):
             current_sorting_criteria = self._sorting_criteria(req)
diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py
index 4303ad4b..5b6a3312 100644
--- a/src/lighteval/tasks/default_prompts.py
+++ b/src/lighteval/tasks/default_prompts.py
@@ -176,7 +176,6 @@ def bbh_harness(line, task_name: str = None):
         query=query,
         choices=choices,
         gold_index=correct_index,
-        target_for_fewshot_sorting=choices,
         instruction=line.get("task_prefix", None),
     )

@@ -196,7 +195,6 @@ def bbh_lighteval(line, task_name: str = None):
         query=query,
         choices=LETTER_INDICES[: len(line["choices"])],
         gold_index=line["target_idx"],
-        target_for_fewshot_sorting=LETTER_INDICES[: len(line["choices"])],
         instruction=line.get("task_prefix", None),
     )

@@ -207,7 +205,6 @@ def bbh(line, instruction, choices, task_name: str = None):
         query=f"{instruction}Q: {line['input']}\nA:",
         choices=choices,
         gold_index=choices.index(line["target"]),
-        target_for_fewshot_sorting=[f" {c}" for c in choices],
         instruction=instruction,
     )

@@ -799,7 +796,6 @@ def hellaswag_generative(line, task_name: str = None):
         choices=[" " + i for i in LETTER_INDICES[: len(line["endings"])]],
         gold_index=gold_ix,  # -1 for test,
         instruction="The following are multiple choice questions (with answers) about common sense.\n\n",
-        target_for_fewshot_sorting=line["endings"][gold_ix] if gold_ix > -1 else "",
     )


@@ -1352,7 +1348,6 @@ def mmlu(line, topic, task_name: str = None):
         choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"],
         gold_index=gold_ix,
         instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
-        target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
     )


@@ -1373,7 +1368,6 @@ def custom_mmlu_thom(line, task_name: str = None):
         choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"],
         gold_index=gold_ix,
         instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
-        target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
     )


@@ -1613,7 +1607,6 @@ def mmlu_harness(line, task_name: str = None):
     query += "Answer:"

     gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
-    "__few_shots" in line and line["__few_shots"] is True  # We are adding few shots

     return Doc(
         task_name=task_name,
@@ -1621,7 +1614,6 @@ def mmlu_harness(line, task_name: str = None):
         choices=[" A", " B", " C", " D"],
         gold_index=gold_ix,
         instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
-        target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
     )


@@ -1638,8 +1630,8 @@ def mmlu_helm(line, task_name: str = None):
         query=query,
         choices=[" A", " B", " C", " D"],
         gold_index=gold_ix,
+        fewshot_sorting_class=line["choices"][gold_ix],
         instruction=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\n",
-        target_for_fewshot_sorting=line["choices"][gold_ix],  # specific to HELM evals
     )


@@ -1816,7 +1808,6 @@ def openbookqa_helm(line, task_name: str = None):
         choices=["A", "B", "C", "D", "E"],
         gold_index=gold_ix,
         instruction="The following are multiple choice questions (with answers) about common sense.\n",
-        target_for_fewshot_sorting=line["choices"]["text"][gold_ix],  # specific to HELM evals
target_for_fewshot_sorting=line["choices"]["text"][gold_ix], # specific to HELM evals ) @@ -1837,14 +1828,13 @@ def piqa_helm(line, task_name: str = None): query += "Answer: " gold_ix = int(line["label"]) - + is_few_shots = line.get("__few_shots", False) return Doc( task_name=task_name, query=query, - choices=["A", "B"], + choices=["A", "B"] if not is_few_shots else [line["sol1"], line["sol2"]], gold_index=gold_ix, instruction="The following are multiple choice questions (with answers) about common sense.\n", - target_for_fewshot_sorting=[line["sol1"], line["sol2"]][gold_ix], ) @@ -1877,13 +1867,11 @@ def pubmed_qa_helm(line, task_name: str = None): ) query += f"\n\nQuestion: {line['question']}\nAnswer: " gold_ix = ["yes", "no", "maybe"].index(line["final_decision"]) - return Doc( task_name=task_name, query=query, choices=["A", "B", "C"], gold_index=gold_ix, - target_for_fewshot_sorting=["yes", "no", "maybe"][gold_ix], ) @@ -2263,13 +2251,11 @@ def truthful_qa_helm(line, task_name: str = None): query = f"Question: {line['question']}\n" query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) query += "Answer:" - return Doc( task_name=task_name, query=query, choices=LETTER_INDICES[: len(line["choices"])], gold_index=line["gold_index"], - target_for_fewshot_sorting=line["choices"][line["gold_index"]], ) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index b119fa21..cba69457 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -289,11 +289,13 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]: docs = [] for split in splits: - for item in self.dataset[split]: + for ix, item in enumerate(self.dataset[split]): # Some tasks formatting is applied differently when the document is used for fewshot examples # vs when it's used for the actual prompt. That's why we store whether we are currently using the # doc for a fewshot sample (few_shots=True) or not, which then leads to the creation of a different Doc. item["__few_shots"] = few_shots + # Some tasks require to know which is the current item index in order to apply a different prompt template + item["__index"] = ix cur_docs = self.formatter(item, self.name) if cur_docs is None: continue @@ -340,21 +342,6 @@ def eval_docs(self) -> list[Doc]: self._docs = self.remove_duplicate_docs(self._docs) return self._docs - def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str: - """ - Returns the target of the given document. - - Args: - formatted_doc (Doc): Formatted document. - few_shot (bool, optional): Whether the document is used for few - shot examples. Defaults to False. - - Returns: - str: Target of the document, which is the correct answer for a document. 
- """ - # likely we mostly need one example not all - return as_list(formatted_doc.get_golds(few_shot=few_shot))[0] - def construct_requests( self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str ) -> Dict[RequestType, List[Request]]: diff --git a/src/lighteval/tasks/prompt_manager.py b/src/lighteval/tasks/prompt_manager.py index c9fe872b..7555b72a 100644 --- a/src/lighteval/tasks/prompt_manager.py +++ b/src/lighteval/tasks/prompt_manager.py @@ -65,20 +65,33 @@ def doc_to_text(doc: Doc, return_instructions: bool = False) -> Union[str, Tuple ) @staticmethod - def doc_to_target(formatted_doc: Doc, few_shot: bool = False) -> str: + def doc_to_target(formatted_doc: Doc) -> str: """ Returns the target of the given document. Args: formatted_doc (Doc): Formatted document. - few_shot (bool, optional): Whether the document is used for few - shot examples. Defaults to False. Returns: str: Target of the document, which is the correct answer for a document. """ - # likely we mostly need one example not all - return as_list(formatted_doc.get_golds(few_shot=few_shot))[0] + return as_list(formatted_doc.get_golds())[0] + + @staticmethod + def doc_to_fewshot_sorting_class(formatted_doc: Doc) -> str: + """ + In some cases, when selecting few-shot samples, we want to use specific document classes + which need to be specified separately from the target. + For example, a document where the gold is a json might want to use only one of the keys of + the json to define sorting classes in few shot samples. Else we take the gold. + + Args: + formatted_doc (Doc): Formatted document. + + Returns: + str: Class of the + """ + return formatted_doc.fewshot_sorting_class or PromptManager.doc_to_target(formatted_doc) def add_context_to_doc( self, @@ -255,9 +268,7 @@ def get_examples( class FewShotSelectionMethod: sorting: str # sorting method for the overall few shot pool (balanced, random, sequential) with_sampling: bool # samples item randomly from the few shot pool - fewshotpool_unique: ( - bool - ) # set to true if you are CERTAIN there is no intersection between the few shot pool and your evaluation set + fewshotpool_unique: bool # set to true if you are CERTAIN there is no intersection between the few shot pool and your evaluation set class FewShotSelection(Enum): @@ -356,16 +367,16 @@ def _init_fewshot_sampling_balanced( ): fewshotpool = self.task.fewshot_docs() - # rnd = random.Random(variance_seed) random.seed(variance_seed) - # Build up balanced selection based on labels - # Sort by counts of labels + # Build up balanced selection based on fewshot_sorting_class + # (or the gold target, if the class is undefined) label_to_instances = defaultdict(list) for instance in fewshotpool: - target = PromptManager.doc_to_target(instance, few_shot=True) + target = PromptManager.doc_to_fewshot_sorting_class(instance) label_to_instances[target].append(instance) + # Sort by counts of class labels counts_to_labels = defaultdict(list) for label, instances in sorted(label_to_instances.items()): counts_to_labels[len(instances)].append(label) diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py index fb1a8579..cd75ad40 100644 --- a/src/lighteval/tasks/requests.py +++ b/src/lighteval/tasks/requests.py @@ -178,7 +178,7 @@ class Doc: # For few-shot instruction: Optional[str] = "" - target_for_fewshot_sorting: Optional[str] = None # will probably have to be removed in the future + fewshot_sorting_class: Optional[str] = None # class to use to select balanced few-shot samples # 
     ctx: Optional[str] = ""
@@ -194,18 +194,12 @@ def __post_init__(self):
         if self.instruction is None:
             self.instruction = ""

-    def get_golds(self, few_shot: bool = False):
+    def get_golds(self):
         """Return gold targets extracted from the target dict"""
         gold_indices = as_list(self.gold_index)
-        if few_shot and self.target_for_fewshot_sorting is not None:
-            choices = self.target_for_fewshot_sorting
-            if isinstance(choices, str):  # correct choice is already selected
-                return choices
-            else:
-                choices = self.choices
         golds = []
         for gold_ix in gold_indices:
-            golds.extend(as_list(choices[gold_ix]))
+            golds.extend(as_list(self.choices[gold_ix]))
         return golds

     def __repr__(self):
diff --git a/tests/models/test_base_model.py b/tests/models/test_base_model.py
new file mode 100644
index 00000000..dac396f5
--- /dev/null
+++ b/tests/models/test_base_model.py
@@ -0,0 +1,37 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from lighteval.models.base_model import BaseModel
+from lighteval.models.model_config import BaseModelConfig
+from lighteval.models.model_loader import load_model
+from lighteval.utils.utils import EnvConfig
+
+
+def test_empty_requests():
+    model_config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    model: BaseModel = load_model(config=model_config, env_config=EnvConfig(cache_dir="."))
+
+    assert model.loglikelihood([]) == []
+    assert model.loglikelihood_single_token([]) == []
+    assert model.loglikelihood_rolling([]) == []
+    assert model.greedy_until([]) == []
+    assert model.greedy_until_multi_turn([]) == []
diff --git a/tests/reference_scores/harness_prompts.json b/tests/reference_scores/harness_prompts.json
index b79a6637..6cd942ef 100644
--- a/tests/reference_scores/harness_prompts.json
+++ b/tests/reference_scores/harness_prompts.json
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:02a5551e1137c799c9a1535112d221c7a77fd07b72c2b38b640164be7ea70828
-size 20246141
+oid sha256:4d7055452bb1f282b8b2c040a3a30856f51aa8d44fe80e2c391cbbc375a19b95
+size 20244716