diff --git a/examples/zero_shot_prompting.py b/examples/zero_shot_prompting.py
index 3abb76b..3d86ac0 100644
--- a/examples/zero_shot_prompting.py
+++ b/examples/zero_shot_prompting.py
@@ -38,8 +38,8 @@ def __init__(
         self.recipe = smashed.recipes.JinjaRecipe(
             tokenizer=self.tokenizer,
             jinja_template=template,
-            max_source_content_length=max_source_content_length,
-            max_target_content_length=max_target_content_length,
+            max_source_length_per_shot=max_source_content_length,
+            max_target_length_per_shot=max_target_content_length,
         ) >> smashed.recipes.CollatorRecipe(
             tokenizer=self.tokenizer,
             device=device,
diff --git a/pyproject.toml b/pyproject.toml
index d253d0c..d890e26 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "smashed"
-version = "0.15.3"
+version = "0.15.4"
 description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
 authors = [
     {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" },
diff --git a/src/smashed/mappers/prompting.py b/src/smashed/mappers/prompting.py
index 3987bbc..472e739 100644
--- a/src/smashed/mappers/prompting.py
+++ b/src/smashed/mappers/prompting.py
@@ -245,7 +245,6 @@ def _find_truncated_lens_longest(
         redistributed_extra_len = cls._find_truncated_lens_uniform(
             lens=longer_than_average,
             max_len=extra_len_to_redistribute,
-            # max_length=max_len,
         )

         # we figure out new lengths by adding the redistributed extra length
diff --git a/src/smashed/recipes/promptsource.py b/src/smashed/recipes/promptsource.py
index 6ef2ea6..f2f453c 100644
--- a/src/smashed/recipes/promptsource.py
+++ b/src/smashed/recipes/promptsource.py
@@ -17,8 +17,8 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
         jinja_template: str,
         num_shots: int = 0,
-        max_source_content_length: Optional[int] = None,
-        max_target_content_length: Optional[int] = None,
+        max_source_length_per_shot: Optional[int] = None,
+        max_target_length_per_shot: Optional[int] = None,
         truncation_strategy: Literal["longest", "uniform"] = "longest",
         use_words: bool = True,
         source_fields: Optional[Sequence[str]] = None,
@@ -35,14 +35,15 @@ def __init__(
                 the source and target; we use promptsource to parse the template
                 and extract the source and target fields; please see the
                 promptsource documentation for more details.
-            max_source_content_length (Optional[int], optional): the maximum
-                length of the source content (i.e., the content that is given
-                as input to the model). If not provided, no truncation will
-                be performed. Defaults to None.
+            max_source_length_per_shot (Optional[int], optional): the maximum
+                length of all the fields that are part of the source in a
+                prompting shot. If not provided, no truncation will be
+                performed. Defaults to None.
-            max_target_content_length (Optional[int], optional): the maximum
-                length of the target content (i.e., the content that is
-                expected as output from the model). If not provided, no
-                truncation will be performed. Defaults to None.
+            max_target_length_per_shot (Optional[int], optional): the maximum
+                length of all the fields that are part of the target in a
+                prompting shot (that is, the text the model is asked to
+                generate). If not provided, no truncation will be performed.
+                Defaults to None.
             truncation_strategy ("longest" or "uniform"], optional): how to
                 perform truncation if the source or target content is longer
                 than the maximum length. If "longest", the longest fields
@@ -124,17 +125,42 @@ def __init__(
             # if we don't use words, we just use the length of the prompt
            # in characters.
             length_src_prompt = len(source_text)
-            length_tgt_prompt = len(target_text)
+            # for target, we actually take the max in case there are multiple,
+            # and 0 if there are none.
+            length_tgt_prompt = max([len(t) for t in target_text] or [0])
+
+        # small helper to round up; avoids importing math.ceil
+        def ceil(x):
+            return int(x + (1 if x % 1 else 0))

-        if max_source_content_length is not None:
+        if max_source_length_per_shot is not None:
             # in case a max length for the source is provided, we need to
-            # truncate; first, we decrease the max length by the length of
-            # prompt text.
-            max_source_content_length -= length_src_prompt
+            # truncate. The total max_length for source data in each shot
+            # needs to be reduced by (a) the length of the target prompt
+            # text when doing few-shot, and (b) the length of the text of
+            # the prompt itself.
+            #
+            # For both (a) and (b), we need to distribute the length by
+            # the number of shots:
+            # (a): recall that each prompt will contain n shots + the
+            #      prompt for the sequence we care about. So when doing
+            #      n-shot, we are adding n target sequences, but are
+            #      truncating n + 1 target sequences. Therefore, we multiply
+            #      the target length by n but divide by (n + 1).
+            # (b): the text that is part of the prompt but is not variables
+            #      (e.g., instructions) must be divided over n + 1 sources.
+            actual_source_context_length = (
+                max_source_length_per_shot
+                - ceil(
+                    (max_target_length_per_shot or 0)
+                    * (num_shots / (num_shots + 1))
+                )
+                - ceil(length_src_prompt / (num_shots + 1))
+            )

             # we raise if the max length is less than one after accounting
             # for the length of the prompt text.
-            if max_source_content_length < 1:
+            if actual_source_context_length < 1:
                 raise ValueError(
                     f"max_source_content_length must be at least equal to "
                     f"the length of the source prompt ({length_src_prompt})!"
@@ -144,17 +170,17 @@ def __init__(
             self.chain(
                 TruncateMultipleFieldsMapper(
                     fields_to_truncate=source_fields,
-                    max_length=max_source_content_length,
+                    max_length=actual_source_context_length,
                     strategy=truncation_strategy,
                 )
             )

-        if len(target_text) > 0 and max_target_content_length:
+        if len(target_text) > 0 and max_target_length_per_shot:
             # we operate here in the same way as for the source, but we
             # only do it if there is a target prompt.
-            max_target_content_length -= length_tgt_prompt
+            max_target_length_per_shot -= length_tgt_prompt

-            if max_target_content_length < 1:
+            if max_target_length_per_shot < 1:
                 raise ValueError(
                     f"max_target_content_length must be at least equal to "
                     f"the length of the target prompt ({length_tgt_prompt})!"
@@ -163,7 +189,7 @@ def __init__(
             self.chain(
                 TruncateMultipleFieldsMapper(
                     fields_to_truncate=target_fields,
-                    max_length=max_target_content_length,
+                    max_length=max_target_length_per_shot,
                     strategy=truncation_strategy,
                 )
             )
diff --git a/tests/test_promptsource.py b/tests/test_promptsource.py
index 9378303..a8211d0 100644
--- a/tests/test_promptsource.py
+++ b/tests/test_promptsource.py
@@ -1,14 +1,44 @@
 import unittest

-from transformers.models.auto import AutoTokenizer
-
 from smashed.mappers.promptsource import (
     FewShotJinjaMapper,
     JinjaMapper,
     PromptsourceMapper,
     SingleTransformPromptsourceMixin,
 )
-from smashed.recipes.promptsource import JinjaRecipe
+
+FEW_SHOT_DATASET = [
+    {
+        "question": "Who is Bill Gates?",
+        "answer": "Bill Gates is a billionaire.",
+    },
+    {
+        "question": "who is john lennon?",
+        "answer": "John Lennon was a musician.",
+    },
+    {
+        "question": "who is john doe?",
+        "answer": "John Doe is a fictional character.",
+    },
+    {
+        "question": "who is goldie hawn?",
+        "answer": "Goldie Hawn is an actress.",
+    },
+    {
+        "question": "who is ru paul?",
+        "answer": "Ru Paul is a drag queen.",
+    },
+]
+
+FEW_SHOT_PROMPT = (
+    "{% for shot in __shots__ %}"
+    "Q: {{shot.question}}\n"
+    "A: {{shot.answer}}\n"
+    "\n"
+    "{% endfor %}"
+    "Q: {{question}}\n"
+    "A: |||{{answer}}"
+)


 class TestPromptsource(unittest.TestCase):
@@ -60,123 +90,46 @@ def test_dataset_prompt_source_mapper(self):
         mapped_dataset2 = mapper2.map(dataset, remove_columns=True)
         self.assertEqual(mapped_dataset, mapped_dataset2)

-    def test_promptsource_recipe(self):
-        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
-
-        recipe = JinjaRecipe(
-            tokenizer=AutoTokenizer.from_pretrained("bert-base-cased"),
-            jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}",
-            max_source_content_length=15,
-            max_target_content_length=5,
-        )
-        dataset = [
-            {
-                "question": "What is the capital of France?",
-                "context": "Paris is the capital of " + ("France " * 10),
-                "answer": "Paris " * 10,
-            }
-        ]
-
-        mapped_dataset, *_ = recipe.map(dataset)
-
-        self.assertEqual(
-            tokenizer.decode(mapped_dataset["input_ids"]),
-            (
-                "Q : What is the capital of France? "
-                "C : Paris is the capital of France "
-                "A :"
-            ),
-        )
-
-        self.assertEqual(
-            tokenizer.decode(mapped_dataset["labels"]),
-            "Paris Paris Paris Paris Paris",
-        )
-
-    def _few_shot_data_prompt(self):
-        dataset = [
-            {
-                "question": "Who is Bill Gates?",
-                "answer": "Bill Gates is a billionaire.",
-            },
-            {
-                "question": "who is john lennon?",
-                "answer": "John Lennon was a musician.",
-            },
-            {
-                "question": "who is john doe?",
-                "answer": "John Doe is a fictional character.",
-            },
-            {
-                "question": "who is goldie hawn?",
-                "answer": "Goldie Hawn is an actress.",
-            },
-            {
-                "question": "who is ru paul?",
-                "answer": "Ru Paul is a drag queen.",
-            },
-        ]
-        jinja_prompt = (
-            "{% for shot in __shots__ %}"
-            "Q: {{shot.question}}\n"
-            "A: {{shot.answer}}\n"
-            "\n"
-            "{% endfor %}"
-            "Q: {{question}}\n"
-            "A: |||{{answer}}"
-        )
-
-        return dataset, jinja_prompt
-
     def test_fewshot_jinja(self):
+        mapper = FewShotJinjaMapper(jinja=FEW_SHOT_PROMPT, num_shots=2)

-        dataset, jinja_prompt = self._few_shot_data_prompt()
-
-        mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=2)
-
-        mapped_dataset = mapper.map(dataset)
+        mapped_dataset = mapper.map(FEW_SHOT_DATASET)

         self.assertEqual(len(mapped_dataset), 1)

         self.assertEqual(
             mapped_dataset[0]["source"],
             (
-                "Q: Who is Bill Gates?\nA: Bill Gates is a billionaire.\n\n"
-                "Q: who is john lennon?\nA: John Lennon was a musician.\n\n"
-                "Q: who is john doe?\nA: "
+                f"Q: {FEW_SHOT_DATASET[0]['question']}\n"
+                f"A: {FEW_SHOT_DATASET[0]['answer']}\n\n"
+                f"Q: {FEW_SHOT_DATASET[1]['question']}\n"
+                f"A: {FEW_SHOT_DATASET[1]['answer']}\n\n"
+                f"Q: {FEW_SHOT_DATASET[2]['question']}\nA: "
             ),
         )
         self.assertEqual(
             mapped_dataset[0]["target"],
-            "John Doe is a fictional character.",
+            FEW_SHOT_DATASET[2]["answer"],
         )

     def test_few_shot_jinja_zero_shots(self):
-        dataset, jinja_prompt = self._few_shot_data_prompt()
+        mapper = FewShotJinjaMapper(jinja=FEW_SHOT_PROMPT, num_shots=0)

-        mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=0)
-
-        mapped_dataset = mapper.map(dataset)
+        mapped_dataset = mapper.map(FEW_SHOT_DATASET)

         self.assertEqual(len(mapped_dataset), 5)

-        self.assertEqual(
-            mapped_dataset[0]["source"], "Q: Who is Bill Gates?\nA: "
-        )
-
-        self.assertEqual(
-            mapped_dataset[0]["target"],
-            "Bill Gates is a billionaire.",
-        )
+        for i in range(5):
+            self.assertEqual(
+                mapped_dataset[i]["source"],
+                f"Q: {FEW_SHOT_DATASET[i]['question']}\nA: ",
+            )

-        self.assertEqual(
-            mapped_dataset[1]["source"], "Q: who is john lennon?\nA: "
-        )
-        self.assertEqual(
-            mapped_dataset[1]["target"],
-            "John Lennon was a musician.",
-        )
+            self.assertEqual(
+                mapped_dataset[i]["target"],
+                FEW_SHOT_DATASET[i]["answer"],
+            )

     def test_few_shot_exception(self):
         with self.assertRaises(KeyError):
"answer": "Goldie Hawn is an actress.", + }, + { + "question": "who is ru paul?", + "answer": "Ru Paul is a drag queen.", + }, + { + "question": "who is john wayne?", + "answer": "John Wayne was an actor; he's dead!", + }, +] + +FEW_SHOT_PROMPT = ( + "{% for shot in __shots__ %}" + "Q: {{shot.question}}\n" + "A: {{shot.answer}}\n" + "\n" + "{% endfor %}" + "Q: {{question}}\n" + "A: |||{{answer}}" +) -class TestPromptsource(unittest.TestCase): - def test_jinja_prompt_source_mapper(self): - mapper = JinjaMapper( - jinja="Q: {{question}}\nA: |||{{answers.text[0]}}" - ) - dataset = [ - { - "question": "What is the capital of France?", - "context": "Paris is the capital of France.", - "answers": {"text": ["Paris"], "answer_start": [0]}, - } - ] - mapped_dataset = mapper.map(dataset, remove_columns=True) - self.assertEqual( - mapped_dataset[0]["source"], - "Q: What is the capital of France?\nA:", - ) - self.assertEqual(mapped_dataset[0]["target"], "Paris") - - def test_dataset_prompt_source_mapper(self): - mapper = PromptsourceMapper( - dataset_name="squad", - template_name="given_context_answer_question_variation", - ) - dataset = [ - { - "question": "What is the capital of France?", - "context": "Paris is the capital of France.", - "answers": {"text": ["Paris"], "answer_start": [0]}, - } - ] - - mapped_dataset = mapper.map(dataset, remove_columns=True) - self.assertEqual(len(mapped_dataset), 1) - self.assertEqual(len(mapped_dataset[0]), 2) - self.assertEqual( - mapped_dataset[0]["source"], - ( - "Paris is the capital of France.\n\n" - "Q: What is the capital of France?\n\nA:" - ), +class TestPromptsource(unittest.TestCase): + def setUp(self) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + "t5-small", model_max_length=512 ) - self.assertEqual(mapped_dataset[0]["target"], "Paris") - - mapper2 = SingleTransformPromptsourceMixin(mapper.template) - mapped_dataset2 = mapper2.map(dataset, remove_columns=True) - self.assertEqual(mapped_dataset, mapped_dataset2) def test_promptsource_recipe(self): - tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") - recipe = JinjaRecipe( - tokenizer=AutoTokenizer.from_pretrained("bert-base-cased"), + tokenizer=self.tokenizer, jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}", - max_source_content_length=15, - max_target_content_length=5, + max_source_length_per_shot=15, + max_target_length_per_shot=5, ) dataset = [ { @@ -80,109 +66,74 @@ def test_promptsource_recipe(self): mapped_dataset, *_ = recipe.map(dataset) self.assertEqual( - tokenizer.decode(mapped_dataset["input_ids"]), + self.tokenizer.decode(mapped_dataset["input_ids"]), ( - "Q : What is the capital of France? " - "C : Paris is the capital of France " - "A :" + "Q: What is the capital of France? 
" + "C: Paris is the capital of France " + "A:" ), ) self.assertEqual( - tokenizer.decode(mapped_dataset["labels"]), + self.tokenizer.decode(mapped_dataset["labels"]), "Paris Paris Paris Paris Paris", ) - def _few_shot_data_prompt(self): - dataset = [ - { - "question": "Who is Bill Gates?", - "answer": "Bill Gates is a billionaire.", - }, - { - "question": "who is john lennon?", - "answer": "John Lennon was a musician.", - }, - { - "question": "who is john doe?", - "answer": "John Doe is a fictional character.", - }, - { - "question": "who is goldie hawn?", - "answer": "Goldie Hawn is an actress.", - }, - { - "question": "who is ru paul?", - "answer": "Ru Paul is a drag queen.", - }, - ] - jinja_prompt = ( - "{% for shot in __shots__ %}" - "Q: {{shot.question}}\n" - "A: {{shot.answer}}\n" - "\n" - "{% endfor %}" - "Q: {{question}}\n" - "A: |||{{answer}}" - ) + def test_few_shot_truncation(self): - return dataset, jinja_prompt - - def test_fewshot_jinja(self): - - dataset, jinja_prompt = self._few_shot_data_prompt() - - mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=2) + mapper = JinjaRecipe( + tokenizer=self.tokenizer, + jinja_template=FEW_SHOT_PROMPT, + num_shots=2, + max_source_length_per_shot=31, + max_target_length_per_shot=14, + use_words=False, + ) - mapped_dataset = mapper.map(dataset) + mapped_dataset = mapper.map(FEW_SHOT_DATASET) - self.assertEqual(len(mapped_dataset), 1) + self.assertEqual(len(mapped_dataset), 2) + # the total non-template length of each prompt is 20, + # which means each should be shortened by ceil(20 / 3) = 7 chars. + # further, we need to account for 5 characters per shot, so that's + # a further ceil((14 * 2) / 3) = 10 chars. So from 31 the effective + # max length should 14 for question. The answers should all get + # truncated to 14 characters. + # + # The fact that the prompt is a bit different from the template + # is totally fine: T5 removes multiple spaces, turns newlines into + # spaces, and decoding strips the trailing spaces. self.assertEqual( - mapped_dataset[0]["source"], + self.tokenizer.decode(mapped_dataset[0]["input_ids"]), ( - "Q: Who is Bill Gates?\nA: Bill Gates is a billionaire.\n\n" - "Q: who is john lennon?\nA: John Lennon was a musician.\n\n" - "Q: who is john doe?\nA: " + f"Q: {FEW_SHOT_DATASET[0]['question'][:14].rstrip()} " + f"A: {FEW_SHOT_DATASET[0]['answer'][:14].rstrip()} " + f"Q: {FEW_SHOT_DATASET[1]['question'][:14].rstrip()} " + f"A: {FEW_SHOT_DATASET[1]['answer'][:14].rstrip()} " + f"Q: {FEW_SHOT_DATASET[2]['question'][:14].rstrip()} " + "A:" ), ) - self.assertEqual( - mapped_dataset[0]["target"], - "John Doe is a fictional character.", - ) - - def test_few_shot_jinja_zero_shots(self): - dataset, jinja_prompt = self._few_shot_data_prompt() - - mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=0) - - mapped_dataset = mapper.map(dataset) - - self.assertEqual(len(mapped_dataset), 5) - - self.assertEqual( - mapped_dataset[0]["source"], "Q: Who is Bill Gates?\nA: " + self.tokenizer.decode(mapped_dataset[0]["labels"]), + FEW_SHOT_DATASET[2]["answer"][:14].rstrip(), ) + # do it for the other few shot sample. 
         self.assertEqual(
-            mapped_dataset[0]["target"],
-            "Bill Gates is a billionaire.",
+            self.tokenizer.decode(mapped_dataset[1]["input_ids"]),
+            (
+                f"Q: {FEW_SHOT_DATASET[3]['question'][:14].rstrip()} "
+                f"A: {FEW_SHOT_DATASET[3]['answer'][:14].rstrip()} "
+                f"Q: {FEW_SHOT_DATASET[4]['question'][:14].rstrip()} "
+                f"A: {FEW_SHOT_DATASET[4]['answer'][:14].rstrip()} "
+                f"Q: {FEW_SHOT_DATASET[5]['question'][:14].rstrip()} "
+                "A:"
+            ),
         )

         self.assertEqual(
-            mapped_dataset[1]["source"], "Q: who is john lennon?\nA: "
-        )
-        self.assertEqual(
-            mapped_dataset[1]["target"],
-            "John Lennon was a musician.",
+            self.tokenizer.decode(mapped_dataset[1]["labels"]),
+            FEW_SHOT_DATASET[5]["answer"][:14].rstrip(),
         )
-
-    def test_few_shot_exception(self):
-        with self.assertRaises(KeyError):
-            FewShotJinjaMapper(
-                jinja="Q: {{question}}\nA: {{answer}}", num_shots=2
-            )
-
-        with self.assertRaises(ValueError):
-            FewShotJinjaMapper("{{ __shots__ }}", num_shots=-2)
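A minimal usage sketch of the renamed recipe parameters, mirroring the updated example script and the recipe test above (tokenizer choice and length values are illustrative, not prescriptive):

    from transformers.models.auto import AutoTokenizer

    from smashed.recipes.promptsource import JinjaRecipe

    recipe = JinjaRecipe(
        tokenizer=AutoTokenizer.from_pretrained("t5-small", model_max_length=512),
        jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}",
        max_source_length_per_shot=15,
        max_target_length_per_shot=5,
    )
    mapped_dataset, *_ = recipe.map(
        [
            {
                "question": "What is the capital of France?",
                "context": "Paris is the capital of " + ("France " * 10),
                "answer": "Paris " * 10,
            }
        ]
    )
    # mapped_dataset["input_ids"] / mapped_dataset["labels"] hold the truncated,
    # tokenized prompt and target, as asserted in test_promptsource_recipe.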