Skip to content

Commit

Permalink
Promptsource recipe (#40)
Browse files Browse the repository at this point in the history
* no-op for recipe

* no-op for recipe

* prompting recipe and test

* small error handling

* documentation

* documentation
  • Loading branch information
soldni authored Jan 6, 2023
1 parent 7e025ed commit 1061cf5
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.13.0"
version = "0.14.0"
description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
authors = [
{name = "Allen Institute for Artificial Intelligence", email = "[email protected]" },
Expand Down
155 changes: 146 additions & 9 deletions src/smashed/mappers/promptsource.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, cast
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, cast

from necessary import Necessary, necessary

Expand Down Expand Up @@ -30,6 +31,35 @@ def __init__(
return_multiple_targets: bool = False,
extra_variables: Optional[Dict[str, Any]] = None,
):
"""Uses a promptsource template to generate source and target sequence;
in the returned dictionary of samples, the source sequence is stored
under the key `source_field_name` and the target sequence is stored
under the key `target_field_name`. If the template does not contain
the control sequence `|||`, then no target sequence is generated.
Args:
template (promptsource.templates.Template): the promptsource
template to use.
source_field_name (str, optional): the name of the field in the
returned dictionary of samples that will contain the source
sequence. Defaults to "source".
target_field_name (str, optional): the name of the field in the
returned dictionary of samples that will contain the target
sequence. Defaults to "target".
truncate (bool, optional): whether to truncate the source and
target sequences to the maximum length allowed by
the promptsource library. Defaults to False.
highlight_variables (bool, optional): whether to highlight the
variables in the source and target sequences with special
html tags. Defaults to False.
return_multiple_targets (bool, optional): whether to return
a list of target sequences for each sample. Defaults to False.
If the template returns multiple targets, but this argument
is set to False, then only the first target is returned.
extra_variables (Optional[Dict[str, Any]], optional): a dictionary
of extra variables that will be passed to the promptsource
template. Defaults to None.
"""

self.template = template
self.truncate = truncate
self.highlight_vars = highlight_variables
Expand All @@ -44,23 +74,65 @@ def __init__(

# abstract syntax tree for the jinja template; we will use it
# to find all fields that are required by the template
ast = Environment().parse(self.template.jinja)
input_fields = sorted(
var_name
for var_name in meta.find_undeclared_variables(ast)
if var_name not in self.extra_vars
)

output_fields = [self.src_fld_name]
if "|||" in self.template.jinja:
output_fields.append(self.tgt_fld_name)

input_src_fields, input_tgt_fields = self.approximate_input_fields
super().__init__(
input_fields=input_fields, output_fields=output_fields
input_fields=set(input_src_fields + input_tgt_fields),
output_fields=output_fields,
)

def _approximate_input_fields(self, jinja_txt: str) -> List[str]:
    """Return the sorted names of undeclared jinja variables appearing
    in ``jinja_txt``, excluding any that are already supplied through
    ``self.extra_vars``."""
    parsed = Environment().parse(jinja_txt)
    undeclared = meta.find_undeclared_variables(parsed)
    required = [name for name in undeclared if name not in self.extra_vars]
    required.sort()
    return required

@property
def approximate_input_fields(self) -> Tuple[List[str], List[str]]:
    """Variables the template is likely to require, as a pair of
    ``(source_fields, target_fields)``; approximate because nested
    variables are ignored."""

    sections = self.template.jinja.split("|||")
    source_section, target_sections = sections[0], sections[1:]

    source_fields = self._approximate_input_fields(source_section)

    collected: set = set()
    for section in target_sections:
        collected.update(self._approximate_input_fields(section))
    target_fields = sorted(collected)

    return source_fields, target_fields

def _approximate_text_from_template(self, txt: str) -> str:
return "".join(part.split("}}")[-1] for part in txt.split("{{"))

@property
def approximate_prompt_text(self) -> Tuple[str, List[str]]:
    """The static prompt text with jinja variables removed; approximate
    because some variables may not be fully stripped."""

    sections = self.template.jinja.split("|||")
    stripped = [self._approximate_text_from_template(s) for s in sections]
    return stripped[0], stripped[1:]

@property
def has_target(self) -> bool:
    """Whether the template produces a target sequence, i.e. whether
    its jinja source contains the ``|||`` separator."""
    return self.template.jinja.find("|||") >= 0

def __getstate__(self) -> dict:
"""We need to serialize the template using yaml so the hash for this
"""We need to serialize thve template using yaml so the hash for this
mapper is consistent across runs."""
out = super().__getstate__()
out["__dict__"]["template"] = yaml.dump(self.template)
Expand Down Expand Up @@ -113,6 +185,37 @@ def __init__(
return_multiple_targets: bool = False,
extra_variables: Optional[Dict[str, Any]] = None,
):
"""Use one of the existing promptsource templates to generate
source and target sequences for a dataset. See the promptsource
repository for a list of available templates:
https://github.com/bigscience-workshop/promptsource
Args:
dataset_name (str): the name of the dataset to use.
template_name (str): the name of the template to use.
subset_name (Optional[str], optional): the name of the subset
to use. Defaults to None.
source_field_name (str, optional): the name of the field in the
returned dictionary of samples that will contain the source
sequence. Defaults to "source".
target_field_name (str, optional): the name of the field in the
returned dictionary of samples that will contain the target
sequence. Defaults to "target".
truncate (bool, optional): whether to truncate the source and
target sequences to the maximum length allowed by
the promptsource library. Defaults to False.
highlight_variables (bool, optional): whether to highlight the
variables in the source and target sequences with special
html tags. Defaults to False.
return_multiple_targets (bool, optional): whether to return
a list of target sequences for each sample. Defaults to False.
If the template returns multiple targets, but this argument
is set to False, then only the first target is returned.
extra_variables (Optional[Dict[str, Any]], optional): a dictionary
of extra variables that will be passed to the promptsource
template. Defaults to None.
"""

# DatasetTemplates is not well annotated, so though subset_name
# is optional, it is annotated as `str`, so we need to cast it.
subset_name = cast(str, subset_name)
Expand Down Expand Up @@ -151,6 +254,40 @@ def __init__(
return_multiple_targets: bool = False,
extra_variables: Optional[Dict[str, Any]] = None,
):
"""Use a custom jinja template to obtain a template from the
promptsource library. See the jinja documentation for a list of
language features and syntax: https://jinja.palletsprojects.com/
Args:
jinja (str): the jinja template to use. The template can access
the data in each sample; the name of fields in the datasets
are available as variables in the template.
name (Optional[str], optional): the name of the template. Defaults
to None.
reference (Optional[str], optional): the reference for the
template. Defaults to None.
metadata (Optional["Template.Metadata"], optional): the metadata
for the template. Defaults to None.
source_field_name (str, optional): the name of the field in the
returned dictionary of samples that will contain the source
sequence. Defaults to "source".
target_field_name (str, optional): the name of the field in the
returned dictionary of samples that will contain the target
sequence. Defaults to "target".
truncate (bool, optional): whether to truncate the source and
target sequences to the maximum length allowed by
the promptsource library. Defaults to False.
highlight_variables (bool, optional): whether to highlight the
variables in the source and target sequences with special
html tags. Defaults to False.
return_multiple_targets (bool, optional): whether to return
a list of target sequences for each sample. Defaults to False.
If the template returns multiple targets, but this argument
is set to False, then only the first target is returned.
extra_variables (Optional[Dict[str, Any]], optional): a dictionary
of extra variables that will be passed to the promptsource
template. Defaults to None.
"""
template = Template(
jinja=jinja,
name=(name or self.name),
Expand Down
2 changes: 2 additions & 0 deletions src/smashed/recipes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Public entry point for smashed recipes: each recipe is imported here
# and re-exported so users can write `from smashed.recipes import <Name>`.
from .collators import CollatorRecipe, SlowCollatorRecipe
from .prompting import PromptingRecipe
from .promptsource import PromptsourceRecipe

# Names exported by `from smashed.recipes import *`; kept sorted.
__all__ = [
    "CollatorRecipe",
    "PromptingRecipe",
    "PromptsourceRecipe",
    "SlowCollatorRecipe",
]
15 changes: 12 additions & 3 deletions src/smashed/recipes/collators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from ..base import BaseRecipe, SingleBaseMapper
Expand Down Expand Up @@ -46,9 +47,13 @@ def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:

return collated_batch

def get_tensorizer(self) -> Python2TorchMapper:
def get_tensorizer(
    self,
    field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> Python2TorchMapper:
    """Build the mapper that converts python lists of ints/floats into
    torch tensors.

    Args:
        field_cast_map: optional mapping from field name to the torch
            dtype (or its string name) that field should be cast to;
            forwarded verbatim to ``Python2TorchMapper``.
        device: optional torch device (or device string) tensors are
            placed on; forwarded verbatim to ``Python2TorchMapper``.
    """
    # this turns lists of ints/floats into tensors
    return Python2TorchMapper(field_cast_map=field_cast_map, device=device)

def get_batcher(self, keep_last: bool) -> FixedBatchSizeMapper:
# the collator already receives the "right" number of samples
Expand All @@ -66,10 +71,14 @@ def __init__(
pad_to_length: Optional[Union[int, Sequence[int]]] = None,
fields_pad_ids: Optional[Mapping[str, int]] = None,
unk_fields_pad_id: Optional[int] = None,
field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None,
device: Optional[Union[torch.device, str]] = None,
) -> None:
super().__init__(do_not_collate=do_not_collate)

self.chain(self.get_tensorizer())
self.chain(
self.get_tensorizer(field_cast_map=field_cast_map, device=device)
)
self.chain(self.get_batcher(keep_last=keep_last))

if tokenizer:
Expand Down
Loading

0 comments on commit 1061cf5

Please sign in to comment.