From 012a738683e7d351432d505f797a0caf01ca2c54 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Sun, 8 Jan 2023 16:46:19 -0800 Subject: [PATCH] Added support for few-shot prompting, decoding of tokenized sequences (#41) * added few shot support * style * added decoder * documentation * added tests for decoder --- examples/qasper.py | 4 +- examples/squad.py | 2 +- examples/zero_shot_prompting.py | 2 +- pyproject.toml | 2 +- src/smashed/mappers/__init__.py | 12 +- src/smashed/mappers/decoding.py | 72 +++++++ src/smashed/mappers/promptsource.py | 318 +++++++++++++++++++++------- src/smashed/recipes/__init__.py | 4 +- src/smashed/recipes/promptsource.py | 57 +++-- tests/test_decoding.py | 56 +++++ tests/test_hf_pickling.py | 4 +- tests/test_promptsource.py | 100 ++++++++- 12 files changed, 516 insertions(+), 117 deletions(-) create mode 100644 src/smashed/mappers/decoding.py create mode 100644 tests/test_decoding.py diff --git a/examples/qasper.py b/examples/qasper.py index 08b02c3..3bbe844 100644 --- a/examples/qasper.py +++ b/examples/qasper.py @@ -48,7 +48,7 @@ def main(): pipeline = ( # concatenate the full text into a single string; use # title_sep, para_sep, sec_sep, and abs_sep to manage separators - sm.JinjaPromptsourceMapper( + sm.JinjaMapper( jinja=( "{{title}}{{abs_sep}}" "{{abstract}}{{abs_sep}}" @@ -143,7 +143,7 @@ def main(): >> sm.WordsToTextMapper( fields=["question", "context", "answers"], ) - >> sm.JinjaPromptsourceMapper( + >> sm.JinjaMapper( jinja=( "Q:{{question}}\nC:{{context}}\nA: " "{% for answer in answers %}|||{{answer}}{% endfor %}" diff --git a/examples/squad.py b/examples/squad.py index 4f0a0ca..44e3ad9 100644 --- a/examples/squad.py +++ b/examples/squad.py @@ -31,7 +31,7 @@ >> sm.WordsToTextMapper( fields=["question", "context", "answers"], ) - >> sm.JinjaPromptsourceMapper( + >> sm.JinjaMapper( jinja=( "Q:{{question}}\nC:{{context}}\nA: " "{% for answer in answers %}|||{{answer}}{% endfor %}" diff --git a/examples/zero_shot_prompting.py 
b/examples/zero_shot_prompting.py index f191390..3abb76b 100644 --- a/examples/zero_shot_prompting.py +++ b/examples/zero_shot_prompting.py @@ -35,7 +35,7 @@ def __init__( self.max_generation_length = max_generation_length - self.recipe = smashed.recipes.PromptsourceRecipe( + self.recipe = smashed.recipes.JinjaRecipe( tokenizer=self.tokenizer, jinja_template=template, max_source_content_length=max_source_content_length, diff --git a/pyproject.toml b/pyproject.toml index be73182..d14bc80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "smashed" -version = "0.14.0" +version = "0.15.0" description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries." authors = [ {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" }, diff --git a/src/smashed/mappers/__init__.py b/src/smashed/mappers/__init__.py index f3db004..394e189 100644 --- a/src/smashed/mappers/__init__.py +++ b/src/smashed/mappers/__init__.py @@ -8,6 +8,7 @@ ) from .converters import Python2TorchMapper, Torch2PythonMapper from .debug import DebugBatchedMapper, DebugSingleMapper +from .decoding import DecodingMapper from .fields import ( ChangeFieldsMapper, EnumerateFieldMapper, @@ -43,11 +44,7 @@ FillTextPromptMapper, TruncateMultipleFieldsMapper, ) -from .promptsource import ( - DatasetPromptsourceMapper, - JinjaPromptsourceMapper, - PromptsourceMapper, -) +from .promptsource import FewShotJinjaMapper, JinjaMapper, PromptsourceMapper from .shape import ( FlattenMapper, SingleSequenceStriderMapper, @@ -69,12 +66,13 @@ "CastMapper", "ChangeFieldsMapper", "CsvLoaderMapper", - "DatasetPromptsourceMapper", + "DecodingMapper", "DebugBatchedMapper", "DebugSingleMapper", "EncodeFieldsMapper", "EndCachingMapper", "EnumerateFieldMapper", + "FewShotJinjaMapper", "FillEncodedPromptMapper", 
"FillTextPromptMapper", "FilterMapper", @@ -86,7 +84,7 @@ "GlomMapper", "HuggingFaceDatasetLoaderMapper", "IndicesToMaskMapper", - "JinjaPromptsourceMapper", + "JinjaMapper", "JsonlLoaderMapper", "LabelsMaskerMapper", "ListCollatorMapper", diff --git a/src/smashed/mappers/decoding.py b/src/smashed/mappers/decoding.py new file mode 100644 index 0000000..01ad66a --- /dev/null +++ b/src/smashed/mappers/decoding.py @@ -0,0 +1,72 @@ +""" +Bunch of decoding mappers to reverse tokenization + +@lucas +""" + +from typing import Any, Dict, Optional, Sequence, Union + +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from ..base import SingleBaseMapper, TransformElementType + +__all__ = ["DecodingMapper"] + + +class DecodingMapper(SingleBaseMapper): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + fields: Union[str, Sequence[str]], + decode_batch: bool = False, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + extra_decode_kwargs: Optional[Dict[str, Any]] = None, + ): + """A mapper that decodes one or more tokenized sequences in + the provided fields. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to use for + decoding; typically, this is the same tokenizer that was used + for tokenization. + fields (Union[str, Sequence[str]]): The fields to decode; could + either be a single field or a sequence of fields. + decode_batch (bool, optional): If True, it assumes each sample is + a sequence of sequences to decode and will use the tokenizer's + `batch_decode` method. If False, it assumes each sample contains + a single sequence to decode and will use the tokenizer's + `decode` method. Defaults to False. + skip_special_tokens (bool, optional): Whether to skip special + tokens (e.g., `[CLS]`, `</s>`, etc) when decoding. Defaults to + False. + clean_up_tokenization_spaces (bool, optional): Whether to clean + up redundant spaces when decoding. Defaults to True. 
+ extra_decode_kwargs (Optional[Dict[str, Any]], optional): Any + tokenizer-specific arguments to pass to the tokenizer's + `batch_decode` method. If not provided, no extra arguments + will be passed. Defaults to None. + """ + + self.tokenizer = tokenizer + self.fields = [fields] if isinstance(fields, str) else fields + self.decode_batch = decode_batch + self.skip_special_tokens = skip_special_tokens + self.clean_up_tokenization_spaces = clean_up_tokenization_spaces + self.extra_decode_kwargs = extra_decode_kwargs or {} + super().__init__(input_fields=self.fields, output_fields=self.fields) + + def transform(self, data: TransformElementType) -> TransformElementType: + return { + field: ( + self.tokenizer.batch_decode + if self.decode_batch + else self.tokenizer.decode + )( + data[field], + skip_special_tokens=self.skip_special_tokens, + clean_up_tokenization_spaces=self.clean_up_tokenization_spaces, + **self.extra_decode_kwargs, + ) + for field in self.fields + } diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 3c3a7a2..334a123 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -1,9 +1,26 @@ -from itertools import chain -from typing import Any, Dict, List, Optional, Tuple, cast +import re +from functools import reduce +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) from necessary import Necessary, necessary -from ..base import SingleBaseMapper, TransformElementType +from ..base.mappers import ( + BatchedBaseMapper, + ChainableMapperMixIn, + SingleBaseMapper, + TransformElementType, +) from ..utils import get_name_and_version with necessary("promptsource", soft=True) as PROMPTSOURCE_AVAILABLE: @@ -16,16 +33,23 @@ from jinja2 import Environment, meta +__all__ = [ + "PromptsourceMapper", + "JinjaMapper", + "FewShotJinjaMapper", +] + + @Necessary( "promptsource", message="{module_name} missing. 
Fix with 'pip install smashed[prompting]'", ) -class PromptsourceMapper(SingleBaseMapper): +class PromptsourceMixin(ChainableMapperMixIn): def __init__( self, template: "Template", - source_field_name: str = "source", - target_field_name: str = "target", + output_source_field_name: str = "source", + output_target_field_name: str = "target", truncate: bool = False, highlight_variables: bool = False, return_multiple_targets: bool = False, @@ -39,12 +63,12 @@ def __init__( Args: template (promptsource.templates.Template): the promptsource template to use. - source_field_name (str, optional): the name of the field in the - returned dictionary of samples that will contain the source - sequence. Defaults to "source". - target_field_name (str, optional): the name of the field in the - returned dictionary of samples that will contain the target - sequence. Defaults to "target". + output_source_field_name (str, optional): the name of the field + in the returned dictionary of samples that will contain the + source sequence. Defaults to "source". + output_target_field_name (str, optional): the name of the field + in the returned dictionary of samples that will contain the + target sequence. Defaults to "target". truncate (bool, optional): whether to truncate the source and target sequences to the maximum length allowed by the promptsource library. Defaults to False. 
@@ -63,72 +87,68 @@ def __init__( self.template = template self.truncate = truncate self.highlight_vars = highlight_variables - self.src_fld_name = source_field_name - self.tgt_fld_name = target_field_name - self.return_multi_tgt = return_multiple_targets - self.extra_vars = extra_variables or {} # override the id for the template because by default it uses # a randomly generated uuid which makes hashing impossible setattr(self.template, "id", 0) - # abstract syntax tree for the jinja template; we will use it - # to find all fields that are required by the template + self.src_fld_name = output_source_field_name + self.tgt_fld_name = output_target_field_name + self.return_multiple_targets = return_multiple_targets + self.extra_vars = extra_variables or {} + # merge all fields from source and targets portion of the template + input_fields: Set[str] = reduce( + lambda t, s: t.union(s), self.approx_input_fields, set() + ) + + # the output field only contains the target field if the template + # has a target portion. 
output_fields = [self.src_fld_name] if "|||" in self.template.jinja: output_fields.append(self.tgt_fld_name) - input_src_fields, input_tgt_fields = self.approximate_input_fields super().__init__( - input_fields=set(input_src_fields + input_tgt_fields), - output_fields=output_fields, - ) - - def _approximate_input_fields(self, jinja_txt: str) -> List[str]: - ast = Environment().parse(jinja_txt) - return sorted( - var_name - for var_name in meta.find_undeclared_variables(ast) - if var_name not in self.extra_vars + input_fields=input_fields, output_fields=output_fields ) @property - def approximate_input_fields(self) -> Tuple[List[str], List[str]]: - """Input fields that are likely to be required by the template; - It is approximate because we ignore nested variables.""" - - source_template, *target_templates = self.template.jinja.split("|||") - source_fields = self._approximate_input_fields(source_template) - target_fields = sorted( + def approx_input_fields(self) -> Tuple[Set[str], ...]: + """A tuple of sets of input fields that are required by the + template. + + The first set contains input fields that are + in the source part of the template (i.e. before the control + sequence `|||`); subsequent sets contain input fields that + are in the targets. + + This is a conservative estimate of the input fields required, + since we can't parse out cases where for loops or if statements + are used, nor cases where members of a variable are accessed. 
+ """ + return tuple( set( - chain.from_iterable( - self._approximate_input_fields(template) - for template in target_templates + field + for field in meta.find_undeclared_variables( + Environment().parse(t) ) + if field not in self.extra_vars ) + for t in self.template.jinja.split("|||") ) - return source_fields, target_fields - - def _approximate_text_from_template(self, txt: str) -> str: - return "".join(part.split("}}")[-1] for part in txt.split("{{")) @property - def approximate_prompt_text(self) -> Tuple[str, List[str]]: - """The prompt without the variables; it is approximate because - we might not be able to remove all variables.""" - - source_template, *target_templates = self.template.jinja.split("|||") - - source_str = self._approximate_text_from_template(source_template) - target_str = [ - self._approximate_text_from_template(template) - for template in target_templates - ] - return source_str, target_str + def template_text(self) -> Tuple[str, ...]: + """The text of the template, with all variables and + control sequences removed.""" + return tuple( + re.sub(r"\{(%|\{|#).+?(#|%|\})\}", "", t) + for t in self.template.jinja.split("|||") + ) @property def has_target(self) -> bool: + """Whether the template has one or more target sequence.""" return "|||" in self.template.jinja def __getstate__(self) -> dict: @@ -146,33 +166,58 @@ def __setstate__(self, state: dict) -> None: state["__dict__"]["template"], Loader=yaml.FullLoader ) - def transform(self, data: TransformElementType) -> TransformElementType: + def apply_template(self, data: Dict[str, Any]) -> Sequence[str]: + """Given a dictionary of data, apply the template to generate + source sequence and target sequence(s).""" + if self.extra_vars: # add any extra variables to the data data = {**data, **self.extra_vars} - src, *tgt = self.template.apply( + return self.template.apply( data, truncate=self.truncate, highlight_variables=self.highlight_vars, ) - if self.return_multi_tgt: + + def format_output( 
+ self, output: Sequence[str] + ) -> Dict[str, Union[str, List[str]]]: + """Given a list of source and target sequences, format the output + as a dictionary of samples; if `return_multiple_targets` is True, + then the target field will be a list of strings, otherwise it will + be a single string.""" + + # unpack for convenience; we will have to slice anyway later + src, *tgt = output + + if self.return_multiple_targets: + # ok to return multiple targets, so we return a list return {self.src_fld_name: src, self.tgt_fld_name: tgt} - elif len(tgt) == 0: + + if len(tgt) == 0: + # no target, so just return the source return {self.src_fld_name: src} - elif len(tgt) > 1: + + if len(tgt) > 1: + # we want to return a single target, but there are multiple! + # therefore, we raise an error. raise ValueError( "Multiple targets, but `return_multiple_targets` is False" ) - else: - return {self.src_fld_name: src, self.tgt_fld_name: tgt[0]} + return {self.src_fld_name: src, self.tgt_fld_name: tgt[0]} -@Necessary( - "promptsource", - message="{module_name} missing. Fix with 'pip install smashed[prompting]'", -) -class DatasetPromptsourceMapper(PromptsourceMapper): + +class SingleTransformPromptsourceMixin(PromptsourceMixin, SingleBaseMapper): + # We need this class pretty much just so that we can inherit from + # SingleBaseMapper. 
+ def transform(self, data: TransformElementType) -> TransformElementType: + encoded = self.apply_template(data) # type: ignore + return self.format_output(encoded) + + +class PromptsourceMapper(SingleTransformPromptsourceMixin): def __init__( self, dataset_name: str, @@ -227,8 +272,8 @@ def __init__( super().__init__( template=template, - source_field_name=source_field_name, - target_field_name=target_field_name, + output_source_field_name=source_field_name, + output_target_field_name=target_field_name, truncate=truncate, highlight_variables=highlight_variables, return_multiple_targets=return_multiple_targets, @@ -236,11 +281,7 @@ def __init__( ) -@Necessary( - "promptsource", - message="{module_name} missing. Fix with 'pip install smashed[prompting]'", -) -class JinjaPromptsourceMapper(PromptsourceMapper): +class JinjaMapper(SingleTransformPromptsourceMixin): def __init__( self, jinja: str, @@ -290,16 +331,139 @@ def __init__( """ template = Template( jinja=jinja, - name=(name or self.name), + name=name, reference=(reference or get_name_and_version()), metadata=metadata, ) super().__init__( template=template, - source_field_name=source_field_name, - target_field_name=target_field_name, + output_source_field_name=source_field_name, + output_target_field_name=target_field_name, truncate=truncate, highlight_variables=highlight_variables, return_multiple_targets=return_multiple_targets, extra_variables=extra_variables, ) + + +class FewShotJinjaMapper(PromptsourceMixin, BatchedBaseMapper): + def __init__( + self, + jinja: str, + num_shots: int, + name: Optional[str] = None, + reference: Optional[str] = None, + metadata: Optional["Template.Metadata"] = None, + keep_last: bool = False, + output_source_field_name: str = "source", + output_target_field_name: str = "target", + truncate: bool = False, + highlight_variables: bool = False, + return_multiple_targets: bool = False, + extra_variables: Optional[Dict[str, Any]] = None, + ): + """Uses a jinja to generate source and 
target sequence; + in the returned dictionary of samples, the source sequence is stored + under the key `output_source_field_name` and the target sequence + under the key `output_target_field_name`. If the template does not + contain the control sequence `|||`, no target sequence is generated. + + Args: + jinja (str): the jinja template to use. The template can access + the data in each sample; the names of the fields in the dataset + are available as variables in the template. A special + variable __shots__ is available, which contains all the shots + for the sample. + num_shots (int): the number of shots to include for each sample. + name (Optional[str], optional): the name of the template. Defaults + to None. + reference (Optional[str], optional): the reference ID for the + template. Defaults to None. + metadata (Optional["Template.Metadata"], optional): the metadata + for the template. Defaults to None. + keep_last (bool, optional): whether to keep the last shot in the + dataset if we have leftover samples less than the number of + shots. Defaults to False. + output_source_field_name (str, optional): the name of the field + in the returned dictionary of samples that will contain the + source sequence. Defaults to "source". + output_target_field_name (str, optional): the name of the field + in the returned dictionary of samples that will contain the + target sequence. Defaults to "target". + truncate (bool, optional): whether to truncate the source and + target sequences to the maximum length allowed by + the promptsource library. Defaults to False. + highlight_variables (bool, optional): whether to highlight the + variables in the source and target sequences with special + html tags. Defaults to False. + return_multiple_targets (bool, optional): whether to return + a list of target sequences for each sample. Defaults to False. + If the template returns multiple targets, but this argument + is set to False, then a ValueError is raised. 
+ extra_variables (Optional[Dict[str, Any]], optional): a dictionary + of extra variables that will be passed to the promptsource + template. Defaults to None. + """ + if not isinstance(num_shots, int) or num_shots < 0: + raise ValueError( + "number_of_shots must be a non-negative integer, " + f"but got {num_shots}" + ) + + if not re.search(r"\b__shots__\b", jinja): + raise ValueError( + "the jinja template must contain the variable __shots__" + ) + + template = Template( + jinja=jinja, + name=name, + reference=(reference or get_name_and_version()), + metadata=metadata, + ) + + self.num_shots = num_shots + self.keep_last = keep_last + + super().__init__( + template=template, + output_source_field_name=output_source_field_name, + output_target_field_name=output_target_field_name, + truncate=truncate, + highlight_variables=highlight_variables, + return_multiple_targets=return_multiple_targets, + extra_variables=extra_variables, + ) + + @property + def approx_input_fields(self) -> Tuple[Set[str], ...]: + return tuple( + set(f for f in fields if f != "__shots__") + for fields in super().approx_input_fields + ) + + def transform( + self, data: Iterable[TransformElementType] + ) -> Iterable[TransformElementType]: + + accumulator: List[TransformElementType] = [] + + for sample in data: + if len(accumulator) < self.num_shots: + accumulator.append(sample) + else: + output = self.apply_template( + {**sample, "__shots__": accumulator} + ) + accumulator = [] + yield self.format_output(output) + + if self.keep_last and len(accumulator) > 0: + # we yield the last bit of the dataset; might have + # fewer than self.num_shots samples + + # use the last as the non-context sample + *accumulator, sample = accumulator + + output = self.apply_template({**sample, "__shots__": accumulator}) + yield self.format_output(output) diff --git a/src/smashed/recipes/__init__.py b/src/smashed/recipes/__init__.py index 34d8226..4448c2b 100644 --- a/src/smashed/recipes/__init__.py +++ 
b/src/smashed/recipes/__init__.py @@ -1,10 +1,10 @@ from .collators import CollatorRecipe, SlowCollatorRecipe from .prompting import PromptingRecipe -from .promptsource import PromptsourceRecipe +from .promptsource import JinjaRecipe __all__ = [ "CollatorRecipe", "PromptingRecipe", - "PromptsourceRecipe", + "JinjaRecipe", "SlowCollatorRecipe", ] diff --git a/src/smashed/recipes/promptsource.py b/src/smashed/recipes/promptsource.py index e9ab8de..ee54f06 100644 --- a/src/smashed/recipes/promptsource.py +++ b/src/smashed/recipes/promptsource.py @@ -1,16 +1,17 @@ -from typing import Literal, Optional, Sequence +from functools import reduce +from typing import Literal, Optional, Sequence, Set, cast from transformers.tokenization_utils_base import PreTrainedTokenizerBase from ..base.recipes import BaseRecipe from ..mappers.fields import ChangeFieldsMapper from ..mappers.prompting import TruncateMultipleFieldsMapper -from ..mappers.promptsource import JinjaPromptsourceMapper +from ..mappers.promptsource import JinjaMapper from ..mappers.text import TextToWordsMapper, WordsToTextMapper from ..mappers.tokenize import TokenizerMapper -class PromptsourceRecipe(BaseRecipe): +class JinjaRecipe(BaseRecipe): def __init__( self, tokenizer: PreTrainedTokenizerBase, @@ -19,6 +20,8 @@ def __init__( max_target_content_length: Optional[int] = None, truncation_strategy: Literal["longest", "uniform"] = "longest", use_words: bool = True, + source_fields: Optional[Sequence[str]] = None, + target_fields: Optional[Sequence[str]] = None, additional_fields_to_keep: Optional[Sequence[str]] = None, ) -> None: """A recipe for a pipeline that uses promptsource to format data @@ -48,6 +51,12 @@ def __init__( use_words (bool, optional): When truncating, whether to use count of words or count of characters. Defaults to True, which means that we use count of words. + source_fields (Optional[Sequence[str]], optional): The fields in + the template that are the source. 
If not provided, we will + try to infer them from the template. Defaults to None. + target_fields (Optional[Sequence[str]], optional): The fields in + the template that are the target. If not provided, we will + try to infer them from the template. Defaults to None. additional_fields_to_keep (Optional[Sequence[str]], optional): After the recipe has been applied, we drop all columns that are not 'input_ids', 'attention_mask', or 'labels'. If you @@ -61,35 +70,50 @@ def __init__( # in the prompt that is not variable placeholders; however, we will # wait till truncation mappers are added to the pipeline before # instantiating the template mapper. - template_mapper = JinjaPromptsourceMapper(jinja=jinja_template) - src_fields, tgt_fields = template_mapper.approximate_input_fields - src_text, tgt_text = template_mapper.approximate_prompt_text + template_mapper = JinjaMapper(jinja=jinja_template) + + # if not provided, we try to infer the source and target fields + source_fields = list( + source_fields or template_mapper.approx_input_fields[0] + ) + target_fields = list( + target_fields + or reduce( + lambda t, s: t.union(s), + template_mapper.approx_input_fields[1:], + cast(Set[str], set()), # cast necessary for mypy + ) + ) + + # we get the text used in the prompt for source and target + # that is not a variable placeholder or control sequence . + source_text, *target_text = template_mapper.template_text if use_words: # if we we need to first set up a text -> words splitter for # the fields in the template text_to_words = TextToWordsMapper( - fields=list(set(src_fields + tgt_fields)) + fields=source_fields + target_fields ) self.chain(text_to_words) # we also need to calculate the lengths in words of the part of # the prompt that is not content; that way we can subtract it # from the max content length, for both source and target. 
+ length_src_prompt = len(text_to_words.splitter(source_text)) # for target, we actually take the max in case there are multiple # prompt versions. length_tgt_prompt = max( - [len(text_to_words.splitter(t)) for t in tgt_text] + [len(text_to_words.splitter(t)) for t in target_text] # in case tgt_text is empty, we use 0 as a default value or [0] ) else: # if we don't use words, we just use the length of the prompt # in characters. - length_src_prompt = len(src_text) - length_tgt_prompt = len(tgt_text) + length_src_prompt = len(source_text) + length_tgt_prompt = max([len(t) for t in target_text] or [0]) if max_source_content_length is not None: # in case a max length for the source is provided, we need to @@ -108,16 +132,17 @@ def __init__( # finally we add a mapper that truncates the source fields. self.chain( TruncateMultipleFieldsMapper( - fields_to_truncate=src_fields, + fields_to_truncate=source_fields, max_length=max_source_content_length, strategy=truncation_strategy, ) ) - if tgt_text and max_target_content_length: + if len(target_text) > 0 and max_target_content_length: # we operate here in the same way as for the source, but we # only do it if there is a target prompt. max_target_content_length -= length_tgt_prompt + if max_target_content_length < 1: raise ValueError( f"max_target_content_length must be at least equal to " @@ -126,7 +151,7 @@ def __init__( self.chain( TruncateMultipleFieldsMapper( - fields_to_truncate=tgt_fields, + fields_to_truncate=target_fields, max_length=max_target_content_length, strategy=truncation_strategy, ) @@ -135,9 +160,7 @@ def __init__( if use_words: # if we used words, we need to convert the fields back to text # before filling the template. - self.chain( - WordsToTextMapper(fields=list(set(src_fields + tgt_fields))) - ) + self.chain(WordsToTextMapper(fields=source_fields + target_fields)) # we only add the template here because we first need to truncate # the fields! 
diff --git a/tests/test_decoding.py b/tests/test_decoding.py new file mode 100644 index 0000000..788f2a9 --- /dev/null +++ b/tests/test_decoding.py @@ -0,0 +1,56 @@ +import unittest + +from transformers.models.auto import AutoTokenizer + +from smashed.mappers.decoding import DecodingMapper +from smashed.mappers.tokenize import TokenizerMapper + + +class TestDecoding(unittest.TestCase): + def setUp(self) -> None: + self.bert_tok = AutoTokenizer.from_pretrained("bert-base-cased") + self.gpt2_tok = AutoTokenizer.from_pretrained("gpt2") + + def test_decoding_mapper(self): + dataset = [ + { + "source": "Translate english to french : this is a test", + "target": "c'est un test", + }, + { + "source": "Translate english to german : this is another test", + "target": "Das ist ein anderer test", + }, + { + "source": "Translate english to italian : tests are important", + "target": "I test sono importanti", + }, + ] + + for tokenizer in [self.bert_tok, self.gpt2_tok]: + mapper = ( + TokenizerMapper( + tokenizer=tokenizer, + input_field="source", + add_special_tokens=False, + return_attention_mask=False, + output_rename_map={"input_ids": "source"}, + ) + >> TokenizerMapper( + tokenizer=tokenizer, + input_field="target", + add_special_tokens=False, + return_attention_mask=False, + output_rename_map={"input_ids": "target"}, + ) + >> DecodingMapper( + tokenizer=tokenizer, + fields=["source", "target"], + ) + ) + + mapped_dataset = mapper.map(dataset) + + for i, d in enumerate(mapped_dataset): + self.assertEqual(d["source"], dataset[i]["source"]) + self.assertEqual(d["target"], dataset[i]["target"]) diff --git a/tests/test_hf_pickling.py b/tests/test_hf_pickling.py index d261a13..ef7099a 100644 --- a/tests/test_hf_pickling.py +++ b/tests/test_hf_pickling.py @@ -7,7 +7,7 @@ from smashed.contrib.squad import ConcatenateContextMapper from smashed.mappers import ( EnumerateFieldMapper, - JinjaPromptsourceMapper, + JinjaMapper, TokenizerMapper, TruncateMultipleFieldsMapper, 
UnpackingMapper, @@ -148,7 +148,7 @@ def test_enumerate(self): ) def test_promptsource(self): - mp = JinjaPromptsourceMapper(jinja="hello {{world}}") + mp = JinjaMapper(jinja="hello {{world}}") dataset = Dataset.from_dict( {"world": [uuid4().hex for _ in range(20)] * 2} diff --git a/tests/test_promptsource.py b/tests/test_promptsource.py index ab6d1ba..0df5ae5 100644 --- a/tests/test_promptsource.py +++ b/tests/test_promptsource.py @@ -3,16 +3,17 @@ from transformers.models.auto import AutoTokenizer from smashed.mappers.promptsource import ( - DatasetPromptsourceMapper, - JinjaPromptsourceMapper, + FewShotJinjaMapper, + JinjaMapper, PromptsourceMapper, + SingleTransformPromptsourceMixin, ) -from smashed.recipes.promptsource import PromptsourceRecipe +from smashed.recipes.promptsource import JinjaRecipe class TestPromptsource(unittest.TestCase): def test_jinja_prompt_source_mapper(self): - mapper = JinjaPromptsourceMapper( + mapper = JinjaMapper( jinja="Q: {{question}}\nA: |||{{answers.text[0]}}" ) dataset = [ @@ -30,7 +31,7 @@ def test_jinja_prompt_source_mapper(self): self.assertEqual(mapped_dataset[0]["target"], "Paris") def test_dataset_prompt_source_mapper(self): - mapper = DatasetPromptsourceMapper( + mapper = PromptsourceMapper( dataset_name="squad", template_name="given_context_answer_question_variation", ) @@ -55,14 +56,14 @@ def test_dataset_prompt_source_mapper(self): ) self.assertEqual(mapped_dataset[0]["target"], "Paris") - mapper2 = PromptsourceMapper(mapper.template) + mapper2 = SingleTransformPromptsourceMixin(mapper.template) mapped_dataset2 = mapper2.map(dataset, remove_columns=True) self.assertEqual(mapped_dataset, mapped_dataset2) def test_promptsource_recipe(self): tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") - recipe = PromptsourceRecipe( + recipe = JinjaRecipe( tokenizer=AutoTokenizer.from_pretrained("bert-base-cased"), jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}", max_source_content_length=15, @@ -91,3 
+92,88 @@ def test_promptsource_recipe(self): tokenizer.decode(mapped_dataset["labels"]), "Paris Paris Paris Paris Paris", ) + + def _few_shot_data_prompt(self): + dataset = [ + { + "question": "Who is Bill Gates?", + "answer": "Bill Gates is a billionaire.", + }, + { + "question": "who is john lennon?", + "answer": "John Lennon was a musician.", + }, + { + "question": "who is john doe?", + "answer": "John Doe is a fictional character.", + }, + { + "question": "who is goldie hawn?", + "answer": "Goldie Hawn is an actress.", + }, + { + "question": "who is ru paul?", + "answer": "Ru Paul is a drag queen.", + }, + ] + jinja_prompt = ( + "{% for shot in __shots__ %}" + "Q: {{shot.question}}\n" + "A: {{shot.answer}}\n" + "\n" + "{% endfor %}" + "Q: {{question}}\n" + "A: |||{{answer}}" + ) + + return dataset, jinja_prompt + + def test_fewshot_jinja(self): + + dataset, jinja_prompt = self._few_shot_data_prompt() + + mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=2) + + mapped_dataset = mapper.map(dataset) + + self.assertEqual(len(mapped_dataset), 1) + + self.assertEqual( + mapped_dataset[0]["source"], + ( + "Q: Who is Bill Gates?\nA: Bill Gates is a billionaire.\n\n" + "Q: who is john lennon?\nA: John Lennon was a musician.\n\n" + "Q: who is john doe?\nA: " + ), + ) + + self.assertEqual( + mapped_dataset[0]["target"], + "John Doe is a fictional character.", + ) + + def test_few_shot_jinja_zero_shots(self): + dataset, jinja_prompt = self._few_shot_data_prompt() + + mapper = FewShotJinjaMapper(jinja=jinja_prompt, num_shots=0) + + mapped_dataset = mapper.map(dataset) + + self.assertEqual(len(mapped_dataset), 5) + + self.assertEqual( + mapped_dataset[0]["source"], "Q: Who is Bill Gates?\nA: " + ) + + self.assertEqual( + mapped_dataset[0]["target"], + "Bill Gates is a billionaire.", + ) + + self.assertEqual( + mapped_dataset[1]["source"], "Q: who is john lennon?\nA: " + ) + self.assertEqual( + mapped_dataset[1]["target"], + "John Lennon was a musician.", + )