Commit 012a738

Added support for few-shot prompting, decoding of tokenized sequences (#41)

* added few shot support

* style

* added decoder

* documentation

* added tests for decoder
soldni authored Jan 9, 2023
1 parent dd98f7a commit 012a738
Showing 12 changed files with 516 additions and 117 deletions.
4 changes: 2 additions & 2 deletions examples/qasper.py
@@ -48,7 +48,7 @@ def main():
pipeline = (
# concatenate the full text into a single string; use
# title_sep, para_sep, sec_sep, and abs_sep to manage separators
-        sm.JinjaPromptsourceMapper(
+        sm.JinjaMapper(
jinja=(
"{{title}}{{abs_sep}}"
"{{abstract}}{{abs_sep}}"
@@ -143,7 +143,7 @@ def main():
>> sm.WordsToTextMapper(
fields=["question", "context", "answers"],
)
-        >> sm.JinjaPromptsourceMapper(
+        >> sm.JinjaMapper(
jinja=(
"Q:{{question}}\nC:{{context}}\nA: "
"{% for answer in answers %}|||{{answer}}{% endfor %}"
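(Not part of the diff: in these templates, `|||` follows the promptsource convention of separating the rendered source prompt from the target answer. Under that assumption, a hypothetical sample renders as sketched below.)

# hypothetical sample for the Q/C/A template above
sample = {"question": "Who?", "context": "Bob did it.", "answers": ["Bob"]}
# source: "Q:Who?\nC:Bob did it.\nA: "   target: "Bob"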
2 changes: 1 addition & 1 deletion examples/squad.py
@@ -31,7 +31,7 @@
>> sm.WordsToTextMapper(
fields=["question", "context", "answers"],
)
-    >> sm.JinjaPromptsourceMapper(
+    >> sm.JinjaMapper(
jinja=(
"Q:{{question}}\nC:{{context}}\nA: "
"{% for answer in answers %}|||{{answer}}{% endfor %}"
2 changes: 1 addition & 1 deletion examples/zero_shot_prompting.py
@@ -35,7 +35,7 @@ def __init__(

self.max_generation_length = max_generation_length

-        self.recipe = smashed.recipes.PromptsourceRecipe(
+        self.recipe = smashed.recipes.JinjaRecipe(
tokenizer=self.tokenizer,
jinja_template=template,
max_source_content_length=max_source_content_length,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.14.0"
version = "0.15.0"
description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
authors = [
{name = "Allen Institute for Artificial Intelligence", email = "[email protected]" },
12 changes: 5 additions & 7 deletions src/smashed/mappers/__init__.py
@@ -8,6 +8,7 @@
)
from .converters import Python2TorchMapper, Torch2PythonMapper
from .debug import DebugBatchedMapper, DebugSingleMapper
+from .decoding import DecodingMapper
from .fields import (
ChangeFieldsMapper,
EnumerateFieldMapper,
@@ -43,11 +44,7 @@
FillTextPromptMapper,
TruncateMultipleFieldsMapper,
)
-from .promptsource import (
-    DatasetPromptsourceMapper,
-    JinjaPromptsourceMapper,
-    PromptsourceMapper,
-)
+from .promptsource import FewShotJinjaMapper, JinjaMapper, PromptsourceMapper
from .shape import (
FlattenMapper,
SingleSequenceStriderMapper,
@@ -69,12 +66,13 @@
"CastMapper",
"ChangeFieldsMapper",
"CsvLoaderMapper",
"DatasetPromptsourceMapper",
"DecodingMapper",
"DebugBatchedMapper",
"DebugSingleMapper",
"EncodeFieldsMapper",
"EndCachingMapper",
"EnumerateFieldMapper",
"FewShotJinjaMapper",
"FillEncodedPromptMapper",
"FillTextPromptMapper",
"FilterMapper",
@@ -86,7 +84,7 @@
"GlomMapper",
"HuggingFaceDatasetLoaderMapper",
"IndicesToMaskMapper",
"JinjaPromptsourceMapper",
"JinjaMapper",
"JsonlLoaderMapper",
"LabelsMaskerMapper",
"ListCollatorMapper",
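The newly exported FewShotJinjaMapper is not exercised anywhere in the loaded diff. Below is a minimal sketch of how it might be wired up, assuming it mirrors JinjaMapper's interface; the `num_shots` parameter and the `__shots__` loop variable are assumptions, not confirmed by this page:

from smashed import mappers as sm

# hypothetical few-shot template: render each sampled shot before the
# final query; `|||` separates the source prompt from the target
few_shot = sm.FewShotJinjaMapper(
    jinja=(
        "{% for shot in __shots__ %}"
        "Q: {{shot.question}}\nA: {{shot.answer}}\n\n"
        "{% endfor %}"
        "Q: {{question}}\nA: |||{{answer}}"
    ),
    num_shots=2,
)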
72 changes: 72 additions & 0 deletions src/smashed/mappers/decoding.py
@@ -0,0 +1,72 @@
"""
Decoding mappers that reverse tokenization.
@lucas
"""

from typing import Any, Dict, Optional, Sequence, Union

from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from ..base import SingleBaseMapper, TransformElementType

__all__ = ["DecodingMapper"]


class DecodingMapper(SingleBaseMapper):
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
fields: Union[str, Sequence[str]],
decode_batch: bool = False,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
extra_decode_kwargs: Optional[Dict[str, Any]] = None,
):
"""A mapper that decodes one or more of tokenized sequences in
for the provided fields.
Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to use for
decoding; typically, this is the same tokenizer that was used
for tokenization.
fields (Union[str, Sequence[str]]): The fields to decode; could
either be a single field or a sequence of fields.
            decode_batch (bool, optional): If True, assumes each sample is
                a sequence of sequences to decode and uses the tokenizer's
                `batch_decode` method. If False, assumes each sample contains
                a single sequence to decode and uses the tokenizer's
                `decode` method. Defaults to False.
skip_special_tokens (bool, optional): Whether to skip special
                tokens (e.g., `[CLS]`, `</s>`, etc.) when decoding. Defaults to
False.
clean_up_tokenization_spaces (bool, optional): Whether to clean
up redundant spaces when decoding. Defaults to True.
            extra_decode_kwargs (Optional[Dict[str, Any]], optional): Any
                tokenizer-specific arguments to pass to the tokenizer's
                `decode`/`batch_decode` method. If not provided, no extra
                arguments will be passed. Defaults to None.
"""

self.tokenizer = tokenizer
self.fields = [fields] if isinstance(fields, str) else fields
self.decode_batch = decode_batch
self.skip_special_tokens = skip_special_tokens
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
self.extra_decode_kwargs = extra_decode_kwargs or {}
super().__init__(input_fields=self.fields, output_fields=self.fields)

def transform(self, data: TransformElementType) -> TransformElementType:
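        # pick `decode` or `batch_decode` based on the decode_batch flag,
        # then apply it to every configured field with the stored kwargs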
return {
field: (
self.tokenizer.batch_decode
if self.decode_batch
else self.tokenizer.decode
)(
data[field],
skip_special_tokens=self.skip_special_tokens,
clean_up_tokenization_spaces=self.clean_up_tokenization_spaces,
**self.extra_decode_kwargs,
)
for field in self.fields
}
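A quick usage sketch for the new mapper (not part of the commit), assuming a Hugging Face tokenizer and the list-of-dictionaries datasets that smashed mappers accept:

from transformers import AutoTokenizer

from smashed import mappers as sm

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# round-trip one sentence: tokenize it, then decode it back to text
dataset = [{"input_ids": tokenizer("hello world")["input_ids"]}]

decoder = sm.DecodingMapper(
    tokenizer=tokenizer,
    fields="input_ids",
    skip_special_tokens=True,  # drop [CLS]/[SEP] from the output
)
print(decoder.map(dataset))  # -> [{'input_ids': 'hello world'}]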
