Skip to content

Commit

Permalink
Added tests for prompt recipe (#45)
Browse files Browse the repository at this point in the history
* added test for jinja recipe

* prettify, fixed warning
  • Loading branch information
soldni authored Jan 12, 2023
1 parent e4ef335 commit 679401d
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 262 deletions.
4 changes: 2 additions & 2 deletions examples/zero_shot_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(
self.recipe = smashed.recipes.JinjaRecipe(
tokenizer=self.tokenizer,
jinja_template=template,
max_source_content_length=max_source_content_length,
max_target_content_length=max_target_content_length,
max_source_length_per_shot=max_source_content_length,
max_target_length_per_shot=max_target_content_length,
) >> smashed.recipes.CollatorRecipe(
tokenizer=self.tokenizer,
device=device,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.15.3"
version = "0.15.4"
description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
authors = [
{name = "Allen Institute for Artificial Intelligence", email = "[email protected]" },
Expand Down
1 change: 0 additions & 1 deletion src/smashed/mappers/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def _find_truncated_lens_longest(
redistributed_extra_len = cls._find_truncated_lens_uniform(
lens=longer_than_average,
max_len=extra_len_to_redistribute,
# max_length=max_len,
)

# we figure out new lengths by adding the redistributed extra length
Expand Down
66 changes: 46 additions & 20 deletions src/smashed/recipes/promptsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def __init__(
tokenizer: PreTrainedTokenizerBase,
jinja_template: str,
num_shots: int = 0,
max_source_content_length: Optional[int] = None,
max_target_content_length: Optional[int] = None,
max_source_length_per_shot: Optional[int] = None,
max_target_length_per_shot: Optional[int] = None,
truncation_strategy: Literal["longest", "uniform"] = "longest",
use_words: bool = True,
source_fields: Optional[Sequence[str]] = None,
Expand All @@ -35,14 +35,15 @@ def __init__(
the source and target; we use promptsource to parse the
template and extract the source and target fields; please
see the promptsource documentation for more details.
max_source_content_length (Optional[int], optional): the maximum
length of the source content (i.e., the content that is given
as input to the model). If not provided, no truncation will
be performed. Defaults to None.
max_source_length_per_shot (Optional[int], optional): the maximum
length of all the fields that are part of the source in a
prompting shot. If not provided, no truncation will be
performed. Defaults to None.
max_target_length_per_shot (Optional[int], optional): the maximum
length of the target content (i.e., the content that is
expected as output from the model). If not provided, no
truncation will be performed. Defaults to None.
length of all the fields that are part of the target in a
prompting shot (that is, the text the model is asked to
generate). If not provided, no truncation will be performed.
Defaults to None.
truncation_strategy ("longest" or "uniform", optional): how to
perform truncation if the source or target content is longer
than the maximum length. If "longest", the longest fields
Expand Down Expand Up @@ -124,17 +125,42 @@ def __init__(
# if we don't use words, we just use the length of the prompt
# in characters.
length_src_prompt = len(source_text)
length_tgt_prompt = len(target_text)
# for target, we actually take the max in case there are multiple,
# and 0 if there are none.
length_tgt_prompt = max([len(t) for t in target_text] or [0])

# one liner to round to ceil. avoid import of math.ceil
def ceil(x):
return int(x + (1 if x % 1 else 0)) # noqa: E731

if max_source_content_length is not None:
if max_source_length_per_shot is not None:
# in case a max length for the source is provided, we need to
# truncate; first, we decrease the max length by the length of
# prompt text.
max_source_content_length -= length_src_prompt
# truncate. The total max_length for source data in each shot
# needs to be reduced by (a) the length of the target prompt
# text when doing few-shot, and (b) the length of text of
# the prompt.
#
# For both (a) and (b), we need to distribute the length by
# the number of shots:
# (a): recall that each prompt will contain n shots + the
# prompt for the sequence we care about. So when doing
# n shot, we are adding n target sequences, but are
# truncating n + 1 target sequences. Therefore, we multiply
# target length by n but divide by (n + 1)
# (b): the text that is part of the prompt but is not variables
# (e.g., instructions) must be divided over n + 1 sources.
actual_source_context_length = (
max_source_length_per_shot
- ceil(
(max_target_length_per_shot or 0)
* (num_shots / (num_shots + 1))
)
- ceil(length_src_prompt / (num_shots + 1))
)

# we raise if the max length is less than one after accounting
# for the length of the prompt text.
if max_source_content_length < 1:
if actual_source_context_length < 1:
raise ValueError(
f"max_source_length_per_shot must be at least equal to "
f"the length of the source prompt ({length_src_prompt})!"
Expand All @@ -144,17 +170,17 @@ def __init__(
self.chain(
TruncateMultipleFieldsMapper(
fields_to_truncate=source_fields,
max_length=max_source_content_length,
max_length=actual_source_context_length,
strategy=truncation_strategy,
)
)

if len(target_text) > 0 and max_target_content_length:
if len(target_text) > 0 and max_target_length_per_shot:
# we operate here in the same way as for the source, but we
# only do it if there is a target prompt.
max_target_content_length -= length_tgt_prompt
max_target_length_per_shot -= length_tgt_prompt

if max_target_content_length < 1:
if max_target_length_per_shot < 1:
raise ValueError(
f"max_target_length_per_shot must be at least equal to "
f"the length of the target prompt ({length_tgt_prompt})!"
Expand All @@ -163,7 +189,7 @@ def __init__(
self.chain(
TruncateMultipleFieldsMapper(
fields_to_truncate=target_fields,
max_length=max_target_content_length,
max_length=max_target_length_per_shot,
strategy=truncation_strategy,
)
)
Expand Down
151 changes: 52 additions & 99 deletions tests/test_promptsource.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
import unittest

from transformers.models.auto import AutoTokenizer

from smashed.mappers.promptsource import (
FewShotJinjaMapper,
JinjaMapper,
PromptsourceMapper,
SingleTransformPromptsourceMixin,
)
from smashed.recipes.promptsource import JinjaRecipe

FEW_SHOT_DATASET = [
{
"question": "Who is Bill Gates?",
"answer": "Bill Gates is a billionaire.",
},
{
"question": "who is john lennon?",
"answer": "John Lennon was a musician.",
},
{
"question": "who is john doe?",
"answer": "John Doe is a fictional character.",
},
{
"question": "who is goldie hawn?",
"answer": "Goldie Hawn is an actress.",
},
{
"question": "who is ru paul?",
"answer": "Ru Paul is a drag queen.",
},
]

FEW_SHOT_PROMPT = (
"{% for shot in __shots__ %}"
"Q: {{shot.question}}\n"
"A: {{shot.answer}}\n"
"\n"
"{% endfor %}"
"Q: {{question}}\n"
"A: </s>|||{{answer}}"
)


class TestPromptsource(unittest.TestCase):
Expand Down Expand Up @@ -60,123 +90,46 @@ def test_dataset_prompt_source_mapper(self):
mapped_dataset2 = mapper2.map(dataset, remove_columns=True)
self.assertEqual(mapped_dataset, mapped_dataset2)

def test_promptsource_recipe(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

recipe = JinjaRecipe(
tokenizer=AutoTokenizer.from_pretrained("bert-base-cased"),
jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}",
max_source_content_length=15,
max_target_content_length=5,
)
dataset = [
{
"question": "What is the capital of France?",
"context": "Paris is the capital of " + ("France " * 10),
"answer": "Paris " * 10,
}
]

mapped_dataset, *_ = recipe.map(dataset)

self.assertEqual(
tokenizer.decode(mapped_dataset["input_ids"]),
(
"Q : What is the capital of France? "
"C : Paris is the capital of France "
"A :"
),
)

self.assertEqual(
tokenizer.decode(mapped_dataset["labels"]),
"Paris Paris Paris Paris Paris",
)

def _few_shot_data_prompt(self):
dataset = [
{
"question": "Who is Bill Gates?",
"answer": "Bill Gates is a billionaire.",
},
{
"question": "who is john lennon?",
"answer": "John Lennon was a musician.",
},
{
"question": "who is john doe?",
"answer": "John Doe is a fictional character.",
},
{
"question": "who is goldie hawn?",
"answer": "Goldie Hawn is an actress.",
},
{
"question": "who is ru paul?",
"answer": "Ru Paul is a drag queen.",
},
]
jinja_prompt = (
"{% for shot in __shots__ %}"
"Q: {{shot.question}}\n"
"A: {{shot.answer}}\n"
"\n"
"{% endfor %}"
"Q: {{question}}\n"
"A: </s>|||{{answer}}"
)

return dataset, jinja_prompt

def test_fewshot_jinja(self):
mapper = FewShotJinjaMapper(jinja=FEW_SHOT_PROMPT, num_shots=2)

dataset, jinja_prompt = self._few_shot_data_prompt()

mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=2)

mapped_dataset = mapper.map(dataset)
mapped_dataset = mapper.map(FEW_SHOT_DATASET)

self.assertEqual(len(mapped_dataset), 1)

self.assertEqual(
mapped_dataset[0]["source"],
(
"Q: Who is Bill Gates?\nA: Bill Gates is a billionaire.\n\n"
"Q: who is john lennon?\nA: John Lennon was a musician.\n\n"
"Q: who is john doe?\nA: </s>"
f"Q: {FEW_SHOT_DATASET[0]['question']}\n"
f"A: {FEW_SHOT_DATASET[0]['answer']}\n\n"
f"Q: {FEW_SHOT_DATASET[1]['question']}\n"
f"A: {FEW_SHOT_DATASET[1]['answer']}\n\n"
f"Q: {FEW_SHOT_DATASET[2]['question']}\nA: </s>"
),
)

self.assertEqual(
mapped_dataset[0]["target"],
"John Doe is a fictional character.",
FEW_SHOT_DATASET[2]["answer"],
)

def test_few_shot_jinja_zero_shots(self):
dataset, jinja_prompt = self._few_shot_data_prompt()
mapper = FewShotJinjaMapper(jinja=FEW_SHOT_PROMPT, num_shots=0)

mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=0)

mapped_dataset = mapper.map(dataset)
mapped_dataset = mapper.map(FEW_SHOT_DATASET)

self.assertEqual(len(mapped_dataset), 5)

self.assertEqual(
mapped_dataset[0]["source"], "Q: Who is Bill Gates?\nA: </s>"
)

self.assertEqual(
mapped_dataset[0]["target"],
"Bill Gates is a billionaire.",
)
for i in range(5):
self.assertEqual(
mapped_dataset[i]["source"],
f"Q: {FEW_SHOT_DATASET[i]['question']}\nA: </s>",
)

self.assertEqual(
mapped_dataset[1]["source"], "Q: who is john lennon?\nA: </s>"
)
self.assertEqual(
mapped_dataset[1]["target"],
"John Lennon was a musician.",
)
self.assertEqual(
mapped_dataset[i]["target"],
FEW_SHOT_DATASET[i]["answer"],
)

def test_few_shot_exception(self):
with self.assertRaises(KeyError):
Expand Down
Loading

0 comments on commit 679401d

Please sign in to comment.