From 1061cf51ea8d20e8f4af1f24329094abf4a4a071 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 5 Jan 2023 21:01:53 -0800 Subject: [PATCH] Promptsource recipe (#40) * no-op for recipe * no-op for recipe * prompting recipe and test * small error handling * documentation * documentation --- pyproject.toml | 2 +- src/smashed/mappers/promptsource.py | 155 ++++++++++++++++++++++-- src/smashed/recipes/__init__.py | 2 + src/smashed/recipes/collators.py | 15 ++- src/smashed/recipes/promptsource.py | 180 ++++++++++++++++++++++++++++ tests/test_promptsource.py | 36 ++++++ 6 files changed, 377 insertions(+), 13 deletions(-) create mode 100644 src/smashed/recipes/promptsource.py diff --git a/pyproject.toml b/pyproject.toml index 89fa877..be73182 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "smashed" -version = "0.13.0" +version = "0.14.0" description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries." authors = [ {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" }, diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 02fbe67..3c3a7a2 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, cast +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple, cast from necessary import Necessary, necessary @@ -30,6 +31,35 @@ def __init__( return_multiple_targets: bool = False, extra_variables: Optional[Dict[str, Any]] = None, ): + """Uses a promptsource template to generate source and target sequence; + in the returned dictionary of samples, the source sequence is stored + under the key `source_field_name` and the target sequence is stored + under the key `target_field_name`. If the template does not contain + the control sequence `|||`, then no target sequence is generated. + Args: + template (promptsource.templates.Template): the promptsource + template to use. + source_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the source + sequence. Defaults to "source". + target_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the target + sequence. Defaults to "target". + truncate (bool, optional): whether to truncate the source and + target sequences to the maximum length allowed by + the promptsource library. Defaults to False. + highlight_variables (bool, optional): whether to highlight the + variables in the source and target sequences with special + html tags. Defaults to False. + return_multiple_targets (bool, optional): whether to return + a list of target sequences for each sample. Defaults to False. + If the template returns multiple targets, but this argument + is set to False, then only the first target is returned. + extra_variables (Optional[Dict[str, Any]], optional): a dictionary + of extra variables that will be passed to the promptsource + template. Defaults to None. + """ + self.template = template self.truncate = truncate self.highlight_vars = highlight_variables @@ -44,23 +74,65 @@ def __init__( # abstract syntax tree for the jinja template; we will use it # to find all fields that are required by the template - ast = Environment().parse(self.template.jinja) - input_fields = sorted( - var_name - for var_name in meta.find_undeclared_variables(ast) - if var_name not in self.extra_vars - ) output_fields = [self.src_fld_name] if "|||" in self.template.jinja: output_fields.append(self.tgt_fld_name) + input_src_fields, input_tgt_fields = self.approximate_input_fields super().__init__( - input_fields=input_fields, output_fields=output_fields + input_fields=set(input_src_fields + input_tgt_fields), + output_fields=output_fields, + ) + + def _approximate_input_fields(self, jinja_txt: str) -> List[str]: + ast = Environment().parse(jinja_txt) + return sorted( + var_name + for var_name in meta.find_undeclared_variables(ast) + if var_name not in self.extra_vars + ) + + @property + def approximate_input_fields(self) -> Tuple[List[str], List[str]]: + """Input fields that are likely to be required by the template; + It is approximate because we ignore nested variables.""" + + source_template, *target_templates = self.template.jinja.split("|||") + source_fields = self._approximate_input_fields(source_template) + target_fields = sorted( + set( + chain.from_iterable( + self._approximate_input_fields(template) + for template in target_templates + ) + ) ) + return source_fields, target_fields + + def _approximate_text_from_template(self, txt: str) -> str: + return "".join(part.split("}}")[-1] for part in txt.split("{{")) + + @property + def approximate_prompt_text(self) -> Tuple[str, List[str]]: + """The prompt without the variables; it is approximate because + we might not be able to remove all variables.""" + + source_template, *target_templates = self.template.jinja.split("|||") + + source_str = self._approximate_text_from_template(source_template) + target_str = [ + self._approximate_text_from_template(template) + for template in target_templates + ] + return source_str, target_str + + @property + def has_target(self) -> bool: + return "|||" in self.template.jinja def __getstate__(self) -> dict: - """We need to serialize the template using yaml so the hash for this + """We need to serialize thve template using yaml so the hash for this mapper is consistent across runs.""" out = super().__getstate__() out["__dict__"]["template"] = yaml.dump(self.template) @@ -113,6 +185,37 @@ def __init__( return_multiple_targets: bool = False, extra_variables: Optional[Dict[str, Any]] = None, ): + """Use one of the existing promptsource templates to generate + source and target sequences for a dataset. See the promptsource + repository for a list of available templates: + https://github.com/bigscience-workshop/promptsource + + Args: + dataset_name (str): the name of the dataset to use. + template_name (str): the name of the template to use. + subset_name (Optional[str], optional): the name of the subset + to use. Defaults to None. + source_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the source + sequence. Defaults to "source". + target_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the target + sequence. Defaults to "target". + truncate (bool, optional): whether to truncate the source and + target sequences to the maximum length allowed by + the promptsource library. Defaults to False. + highlight_variables (bool, optional): whether to highlight the + variables in the source and target sequences with special + html tags. Defaults to False. + return_multiple_targets (bool, optional): whether to return + a list of target sequences for each sample. Defaults to False. + If the template returns multiple targets, but this argument + is set to False, then only the first target is returned. + extra_variables (Optional[Dict[str, Any]], optional): a dictionary + of extra variables that will be passed to the promptsource + template. Defaults to None. + """ + # DatasetTemplates is not well annotated, so though subset_name # is optional, it is annotated as `str`, so we need to cast it. subset_name = cast(str, subset_name) @@ -151,6 +254,40 @@ def __init__( return_multiple_targets: bool = False, extra_variables: Optional[Dict[str, Any]] = None, ): + """Use a custom jinja template to obtain a template from the + promptsource library. See the jinja documentation for a list of + language features and syntax: https://jinja.palletsprojects.com/ + + Args: + jinja (str): the jinja template to use. The template can access + the data in each sample; the name of fields in the datasets + are available as variables in the template. + name (Optional[str], optional): the name of the template. Defaults + to None. + reference (Optional[str], optional): the reference for the + template. Defaults to None. + metadata (Optional["Template.Metadata"], optional): the metadata + for the template. Defaults to None. + source_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the source + sequence. Defaults to "source". + target_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the target + sequence. Defaults to "target". + truncate (bool, optional): whether to truncate the source and + target sequences to the maximum length allowed by + the promptsource library. Defaults to False. + highlight_variables (bool, optional): whether to highlight the + variables in the source and target sequences with special + html tags. Defaults to False. + return_multiple_targets (bool, optional): whether to return + a list of target sequences for each sample. Defaults to False. + If the template returns multiple targets, but this argument + is set to False, then only the first target is returned. + extra_variables (Optional[Dict[str, Any]], optional): a dictionary + of extra variables that will be passed to the promptsource + template. Defaults to None. + """ template = Template( jinja=jinja, name=(name or self.name), diff --git a/src/smashed/recipes/__init__.py b/src/smashed/recipes/__init__.py index d40fc90..34d8226 100644 --- a/src/smashed/recipes/__init__.py +++ b/src/smashed/recipes/__init__.py @@ -1,8 +1,10 @@ from .collators import CollatorRecipe, SlowCollatorRecipe from .prompting import PromptingRecipe +from .promptsource import PromptsourceRecipe __all__ = [ "CollatorRecipe", "PromptingRecipe", + "PromptsourceRecipe", "SlowCollatorRecipe", ] diff --git a/src/smashed/recipes/collators.py b/src/smashed/recipes/collators.py index d856942..e2b48a2 100644 --- a/src/smashed/recipes/collators.py +++ b/src/smashed/recipes/collators.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union +import torch from transformers.tokenization_utils_base import PreTrainedTokenizerBase from ..base import BaseRecipe, SingleBaseMapper @@ -46,9 +47,13 @@ def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]: return collated_batch - def get_tensorizer(self) -> Python2TorchMapper: + def get_tensorizer( + self, + field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None, + device: Optional[Union[torch.device, str]] = None, + ) -> Python2TorchMapper: # this turns lists of ints/floats into tensors - return Python2TorchMapper() + return Python2TorchMapper(field_cast_map=field_cast_map, device=device) def get_batcher(self, keep_last: bool) -> FixedBatchSizeMapper: # the collator already receives the "right" number of samples @@ -66,10 +71,14 @@ def __init__( pad_to_length: Optional[Union[int, Sequence[int]]] = None, fields_pad_ids: Optional[Mapping[str, int]] = None, unk_fields_pad_id: Optional[int] = None, + field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None, + device: Optional[Union[torch.device, str]] = None, ) -> None: super().__init__(do_not_collate=do_not_collate) - self.chain(self.get_tensorizer()) + self.chain( + self.get_tensorizer(field_cast_map=field_cast_map, device=device) + ) self.chain(self.get_batcher(keep_last=keep_last)) if tokenizer: diff --git a/src/smashed/recipes/promptsource.py b/src/smashed/recipes/promptsource.py new file mode 100644 index 0000000..e9ab8de --- /dev/null +++ b/src/smashed/recipes/promptsource.py @@ -0,0 +1,180 @@ +from typing import Literal, Optional, Sequence + +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from ..base.recipes import BaseRecipe +from ..mappers.fields import ChangeFieldsMapper +from ..mappers.prompting import TruncateMultipleFieldsMapper +from ..mappers.promptsource import JinjaPromptsourceMapper +from ..mappers.text import TextToWordsMapper, WordsToTextMapper +from ..mappers.tokenize import TokenizerMapper + + +class PromptsourceRecipe(BaseRecipe): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + jinja_template: str, + max_source_content_length: Optional[int] = None, + max_target_content_length: Optional[int] = None, + truncation_strategy: Literal["longest", "uniform"] = "longest", + use_words: bool = True, + additional_fields_to_keep: Optional[Sequence[str]] = None, + ) -> None: + """A recipe for a pipeline that uses promptsource to format data + as source/target pairs for model prompting. + + Args: + tokenizer (PreTrainedTokenizerBase): A tokenizer to use for + tokenizing the source and target. + jinja_template (str): A jinja template to use for formatting + the source and target; we use promptsource to parse the + template and extract the source and target fields; please + see the promptsource documentation for more details. + max_source_content_length (Optional[int], optional): the maximum + length of the source content (i.e., the content that is given + as input to the model). If not provided, no truncation will + be performed. Defaults to None. + max_target_content_length (Optional[int], optional): the maximum + length of the target content (i.e., the content that is + expected as output from the model). If not provided, no + truncation will be performed. Defaults to None. + truncation_strategy ("longest" or "uniform"], optional): how to + perform truncation if the source or target content is longer + than the maximum length. If "longest", the longest fields + specified in the template will be truncated first. If + "uniform", the fields will be truncated uniformly. Defaults + to "longest". + use_words (bool, optional): When truncating, whether to use count + of words or count of characters. Defaults to True, which means + that we use count of words. + additional_fields_to_keep (Optional[Sequence[str]], optional): + After the recipe has been applied, we drop all columns that + are not 'input_ids', 'attention_mask', or 'labels'. If you + want to keep additional columns, you can specify them here. + Defaults to None. + """ + + super().__init__() + + # we instantiate the template mapper early on so we can get the text + # in the prompt that is not variable placeholders; however, we will + # wait till truncation mappers are added to the pipeline before + # instantiating the template mapper. + template_mapper = JinjaPromptsourceMapper(jinja=jinja_template) + src_fields, tgt_fields = template_mapper.approximate_input_fields + src_text, tgt_text = template_mapper.approximate_prompt_text + + if use_words: + # if we we need to first set up a text -> words splitter for + # the fields in the template + text_to_words = TextToWordsMapper( + fields=list(set(src_fields + tgt_fields)) + ) + self.chain(text_to_words) + + # we also need to calculate the lengths in words of the part of + # the prompt that is not content; that way we can subtract it + # from the max content length, for both source and target. + length_src_prompt = len(text_to_words.splitter(src_text)) + + # for target, we actually take the max in case there are multiple + # prompt versions. + length_tgt_prompt = max( + [len(text_to_words.splitter(t)) for t in tgt_text] + # in case tgt_text is empty, we use 0 as a default value + or [0] + ) + else: + # if we don't use words, we just use the length of the prompt + # in characters. + length_src_prompt = len(src_text) + length_tgt_prompt = len(tgt_text) + + if max_source_content_length is not None: + # in case a max length for the source is provided, we need to + # truncate; first, we decrease the max length by the length of + # prompt text. + max_source_content_length -= length_src_prompt + + # we raise if the max length is less than one after accounting + # for the length of the prompt text. + if max_source_content_length < 1: + raise ValueError( + f"max_source_content_length must be at least equal to " + f"the length of the source prompt ({length_src_prompt})!" + ) + + # finally we add a mapper that truncates the source fields. + self.chain( + TruncateMultipleFieldsMapper( + fields_to_truncate=src_fields, + max_length=max_source_content_length, + strategy=truncation_strategy, + ) + ) + + if tgt_text and max_target_content_length: + # we operate here in the same way as for the source, but we + # only do it if there is a target prompt. + max_target_content_length -= length_tgt_prompt + if max_target_content_length < 1: + raise ValueError( + f"max_target_content_length must be at least equal to " + f"the length of the target prompt ({length_tgt_prompt})!" + ) + + self.chain( + TruncateMultipleFieldsMapper( + fields_to_truncate=tgt_fields, + max_length=max_target_content_length, + strategy=truncation_strategy, + ) + ) + + if use_words: + # if we used words, we need to convert the fields back to text + # before filling the template. + self.chain( + WordsToTextMapper(fields=list(set(src_fields + tgt_fields))) + ) + + # we only add the template here because we first need to truncate + # the fields! + self.chain(template_mapper) + + # tokenize source + self.chain( + TokenizerMapper( + tokenizer=tokenizer, + input_field="source", + add_special_tokens=False, + return_attention_mask=True, + truncation=True, + ) + ) + # we need to keep the input_ids and attention_mask fields + # after the recipe has been applied. + keep_fields = ["input_ids", "attention_mask"] + + if template_mapper.has_target: + # tokenize target + self.chain( + TokenizerMapper( + tokenizer=tokenizer, + input_field="target", + output_rename_map={"input_ids": "labels"}, + add_special_tokens=False, + return_attention_mask=False, + truncation=True, + ) + ) + # the target is in the labels field, so we need to keep it. + keep_fields.append("labels") + + if additional_fields_to_keep: + # this is in case the user wants to keep additional fields + keep_fields.extend(additional_fields_to_keep) + + # finally, we do the field filtering. + self.chain(ChangeFieldsMapper(keep_fields=keep_fields)) diff --git a/tests/test_promptsource.py b/tests/test_promptsource.py index afc42b3..ab6d1ba 100644 --- a/tests/test_promptsource.py +++ b/tests/test_promptsource.py @@ -1,10 +1,13 @@ import unittest +from transformers.models.auto import AutoTokenizer + from smashed.mappers.promptsource import ( DatasetPromptsourceMapper, JinjaPromptsourceMapper, PromptsourceMapper, ) +from smashed.recipes.promptsource import PromptsourceRecipe class TestPromptsource(unittest.TestCase): @@ -55,3 +58,36 @@ def test_dataset_prompt_source_mapper(self): mapper2 = PromptsourceMapper(mapper.template) mapped_dataset2 = mapper2.map(dataset, remove_columns=True) self.assertEqual(mapped_dataset, mapped_dataset2) + + def test_promptsource_recipe(self): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + + recipe = PromptsourceRecipe( + tokenizer=AutoTokenizer.from_pretrained("bert-base-cased"), + jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}", + max_source_content_length=15, + max_target_content_length=5, + ) + dataset = [ + { + "question": "What is the capital of France?", + "context": "Paris is the capital of " + ("France " * 10), + "answer": "Paris " * 10, + } + ] + + mapped_dataset, *_ = recipe.map(dataset) + + self.assertEqual( + tokenizer.decode(mapped_dataset["input_ids"]), + ( + "Q : What is the capital of France? " + "C : Paris is the capital of France " + "A :" + ), + ) + + self.assertEqual( + tokenizer.decode(mapped_dataset["labels"]), + "Paris Paris Paris Paris Paris", + )