Skip to content

Commit

Permalink
Fixes for few shot libraries. (#42)
Browse files Browse the repository at this point in the history
* error handling; recipe can do fewshot

* bumped minor

* various bugfixes
  • Loading branch information
soldni authored Jan 9, 2023
1 parent 012a738 commit b3cc218
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.15.0"
version = "0.15.1"
description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
authors = [
{name = "Allen Institute for Artificial Intelligence", email = "[email protected]" },
Expand Down
19 changes: 18 additions & 1 deletion src/smashed/base/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,24 @@ def _index_fn(t: Tuple[str, Any]) -> int:

return transformed_batch

def one(self, **sample: TransformElementType) -> TransformElementType:
    """Run this mapper on exactly one sample and return the result.

    The keyword arguments are gathered into a single sample, wrapped in
    a one-element list and handed to ``self.map``; the lone element of
    the returned collection is given back to the caller.

    Args:
        **sample: the fields of the sample to transform, passed as
            keyword arguments.

    Returns:
        TransformElementType: the only element returned by ``self.map``.

    Raises:
        ValueError: if ``self.map`` returns anything other than exactly
            one element.
    """
    results = self.map([sample])
    count = len(results)

    # happy path first: exactly one sample in, exactly one sample out
    if count == 1:
        return results[0]

    raise ValueError(
        f"Expected one sample, got {count} samples instead."
    )

@trouting
def map(self, dataset: Any, **map_kwargs: Any) -> Any:
"""Transform a dataset by applying this mapper's transform method.
Expand All @@ -168,7 +186,6 @@ def _map_list_of_dicts(
dataset: Sequence[TransformElementType],
**map_kwargs: Any,
) -> Sequence[TransformElementType]:

# explicitly casting to a boolean since this is all that is
# supported by the simple mapper.
# TODO[lucas]: maybe support specifying which fields to keep?
Expand Down
12 changes: 12 additions & 0 deletions src/smashed/base/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ class ChainableMapperMixIn(AbstractBaseMapper):
fingerprint: str
pipeline: Union["ChainableMapperMixIn", None]

def pip(self, n: Optional[int] = None):
    """Look up a mapper further down the pipeline.

    Args:
        n (Optional[int]): how many hops to take from this mapper.
            When omitted, whatever is stored in ``self.pipeline`` is
            returned as-is (possibly ``None``). ``0`` returns this
            mapper itself.

    Returns:
        The mapper found ``n`` hops away from this one.

    Raises:
        IndexError: if the pipeline ends before ``n`` hops are taken.
    """
    # no index: hand back the direct successor, even when there is none
    if n is None:
        return self.pipeline

    # zero hops left: this is the mapper that was asked for
    if n == 0:
        return self

    # hops remain but there is nowhere left to go
    successor = self.pipeline
    if successor is None:
        raise IndexError("Pipeline index out of range")

    # delegate the remaining hops to the next mapper in the pipeline
    return successor.pip(n - 1)

def __init__(
self,
input_fields: Optional[Iterable[str]] = None,
Expand Down
47 changes: 29 additions & 18 deletions src/smashed/mappers/promptsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -39,6 +40,8 @@
"FewShotJinjaMapper",
]

VARSHOTS = "__shots__"


@Necessary(
"promptsource",
Expand Down Expand Up @@ -112,6 +115,10 @@ def __init__(
input_fields=input_fields, output_fields=output_fields
)

@staticmethod
def get_vars_from_txt(text: str) -> Set[str]:
    """Return the undeclared variable names found in a jinja template.

    Args:
        text (str): the source of the jinja template to inspect.

    Returns:
        Set[str]: the names reported by ``meta.find_undeclared_variables``
            for the parsed template.
    """
    # parse with a fresh, default environment; only the AST is needed
    parsed_template = Environment().parse(text)
    return meta.find_undeclared_variables(parsed_template)

@property
def approx_input_fields(self) -> Tuple[Set[str], ...]:
"""A tuple of sets of input fields that are required by the
Expand All @@ -129,9 +136,7 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]:
return tuple(
set(
field
for field in meta.find_undeclared_variables(
Environment().parse(t)
)
for field in self.get_vars_from_txt(t)
if field not in self.extra_vars
)
for t in self.template.jinja.split("|||")
Expand Down Expand Up @@ -350,7 +355,7 @@ class FewShotJinjaMapper(PromptsourceMixin, BatchedBaseMapper):
def __init__(
self,
jinja: str,
num_shots: int,
num_shots: Union[int, Literal["max"]],
name: Optional[str] = None,
reference: Optional[str] = None,
metadata: Optional["Template.Metadata"] = None,
Expand All @@ -374,7 +379,9 @@ def __init__(
are available as variables in the template. A special
variable __shots__ is available, which contains all the shots
for the sample.
num_shots (int): the number of shots to generate for each sample.
num_shots (Union[int, Literal['max']]): the number of samples to
use for each sample. If set to 'max', then all the samples
in the dataset are used.
name (Optional[str], optional): the name of the template. Defaults
to None.
reference (Optional[str], optional): the reference ID for the
Expand Down Expand Up @@ -404,15 +411,17 @@ def __init__(
of extra variables that will be passed to the promptsource
template. Defaults to None.
"""
if not isinstance(num_shots, int) and num_shots >= 0:
if num_shots != "max" and not (
isinstance(num_shots, int) and num_shots >= 0
):
raise ValueError(
"number_of_shots must be a non-negative integer, "
"number_of_shots must be a non-negative integer or 'max', "
f"but got {num_shots}"
)

if not re.search(r"\b__shots__\b", jinja):
raise ValueError(
"the jinja template must contain the variable __shots__"
if VARSHOTS not in self.get_vars_from_txt(jinja):
raise KeyError(
f"the jinja template must contain the variable {VARSHOTS}"
)

template = Template(
Expand All @@ -422,8 +431,12 @@ def __init__(
metadata=metadata,
)

self.num_shots = num_shots
self.keep_last = keep_last
# mypy complains if we don't retype num_shots
self.num_shots: Union[int, Literal["max"]] = num_shots

# due to how "max" works, we always need to keep the batch
# when in "max" mode, otherwise we will return an empty dataset
self.keep_last: bool = keep_last or num_shots == "max"

super().__init__(
template=template,
Expand All @@ -438,7 +451,7 @@ def __init__(
@property
def approx_input_fields(self) -> Tuple[Set[str], ...]:
return tuple(
set(f for f in fields if f != "__shots__")
set(f for f in fields if f != VARSHOTS)
for fields in super().approx_input_fields
)

Expand All @@ -449,12 +462,10 @@ def transform(
accumulator: List[TransformElementType] = []

for sample in data:
if len(accumulator) < self.num_shots:
if self.num_shots == "max" or len(accumulator) < self.num_shots:
accumulator.append(sample)
else:
output = self.apply_template(
{**sample, "__shots__": accumulator}
)
output = self.apply_template({**sample, VARSHOTS: accumulator})
accumulator = []
yield self.format_output(output)

Expand All @@ -465,5 +476,5 @@ def transform(
# use the last as the non-context sample
*accumulator, sample = accumulator

output = self.apply_template({**sample, "__shots__": accumulator})
output = self.apply_template({**sample, VARSHOTS: accumulator})
yield self.format_output(output)
17 changes: 14 additions & 3 deletions src/smashed/recipes/promptsource.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import reduce
from typing import Literal, Optional, Sequence, Set, cast
from typing import Literal, Optional, Sequence, Set, Union, cast

from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from ..base.recipes import BaseRecipe
from ..mappers.fields import ChangeFieldsMapper
from ..mappers.prompting import TruncateMultipleFieldsMapper
from ..mappers.promptsource import JinjaMapper
from ..mappers.promptsource import VARSHOTS, FewShotJinjaMapper, JinjaMapper
from ..mappers.text import TextToWordsMapper, WordsToTextMapper
from ..mappers.tokenize import TokenizerMapper

Expand All @@ -16,6 +16,7 @@ def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
jinja_template: str,
num_shots: Union[int, Literal["max"]] = 0,
max_source_content_length: Optional[int] = None,
max_target_content_length: Optional[int] = None,
truncation_strategy: Literal["longest", "uniform"] = "longest",
Expand Down Expand Up @@ -64,13 +65,23 @@ def __init__(
Defaults to None.
"""

# must run init before we start using `chain` function to add mappers
super().__init__()

# we instantiate the template mapper early on so we can get the text
# in the prompt that is not variable placeholders; however, we will
# wait till truncation mappers are added to the pipeline before
# adding the template mapper itself to the pipeline.
template_mapper = JinjaMapper(jinja=jinja_template)

# The mapper could be a FewShotJinjaMapper or a JinjaMapper, depending
# on whether it contains the variable `__shots__`.
template_mapper: Union[JinjaMapper, FewShotJinjaMapper]
if VARSHOTS in JinjaMapper.get_vars_from_txt(jinja_template):
template_mapper = FewShotJinjaMapper(
jinja=jinja_template, num_shots=num_shots
)
else:
template_mapper = JinjaMapper(jinja=jinja_template)

# if not provided, we try to infer the source and target fields
source_fields = list(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_promptsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,12 @@ def test_few_shot_jinja_zero_shots(self):
mapped_dataset[1]["target"],
"John Lennon was a musician.",
)

def test_few_shot_exception(self):
    """Check that FewShotJinjaMapper rejects invalid constructor input."""
    # a template that never references __shots__ must be refused
    template_without_shots = "Q: {{question}}\nA: {{answer}}"
    with self.assertRaises(KeyError):
        FewShotJinjaMapper(jinja=template_without_shots, num_shots=2)

    # a negative number of shots must be refused as well
    template_with_shots = "{{ __shots__ }}"
    with self.assertRaises(ValueError):
        FewShotJinjaMapper(jinja=template_with_shots, num_shots=-2)

0 comments on commit b3cc218

Please sign in to comment.