diff --git a/pyproject.toml b/pyproject.toml index d890e26..533eead 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,19 @@ [project] name = "smashed" - version = "0.15.4" -description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries." -authors = [ - {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" }, - {name = "Luca Soldaini", email = "luca@soldaini.net"} -] +version = "0.15.5" +description = """\ +SMASHED is a toolkit designed to apply transformations to samples in \ +datasets, such as fields extraction, tokenization, prompting, batching, \ +and more. Supports datasets from Huggingface, torchdata iterables, or \ +simple lists of dictionaries.\ +""" +# authors = [ +# {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"}, +# {name = "Luca Soldaini", email = "luca@soldaini.net"} +# ] license = {text = "Apache-2.0"} readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "torch>=1.9", "transformers>=4.5", @@ -17,6 +22,7 @@ dependencies = [ "ftfy>=6.1.1", "platformdirs>=2.5.0", "glom>=21.0.0", + "Jinja2>=3.0.3", ] classifiers = [ "Development Status :: 4 - Beta", @@ -32,20 +38,35 @@ keywords = [ "mappers", "pytorch", "torch", - "huggingfae", + "huggingface", "transformers", "datasets", "dict", - "datset", "pipeline", "preprocessing", "nlp", "natural language processing", "text", - "prompting" + "prompting", + "prefix tuning", + "in context learning" ] +[[project.authors]] +name = "Allen Institute for Artificial Intelligence" +email = "contact@allenai.org" + +[[project.authors]] +name = "Luca Soldaini" +email = "luca@soldaini.net" + +[[project.authors]] +name = "Kyle Lo" +email = "kylel@allenai.org" +[[project.maintainers]] +name = "Luca Soldaini" +email = "luca@soldaini.net" [project.urls] "Homepage" = "https://github.com/allenai/smashed" @@ -70,7 +91,7 @@ requires = [ [project.optional-dependencies] dev = [ - "springs>=1.8.3", + "springs>=1.9.1", "black[jupyter]>=21.12b0", "isort>=5.8.0", "mypy>=0.971", @@ -87,17 +108,16 @@ remote = [ "boto3>=1.25.5", ] datasets = [ - "datasets>=2.4.0", + "datasets>=2.8.0", "dill>=0.3.0", ] prompting = [ "promptsource>=0.2.3", "blingfire>=0.1.8", - "PyYAML>=6.0.0", ] torchdata = [ - "torch>=1.12.1", - "torchdata>=0.4.1" + "torch>=1.13.1", + "torchdata>=0.5.1" ] all = [ "smashed[dev]", diff --git a/scripts/push-to-pypi.sh b/release/push-to-pypi.sh similarity index 100% rename from scripts/push-to-pypi.sh rename to release/push-to-pypi.sh diff --git a/src/smashed/base/mappers.py b/src/smashed/base/mappers.py index 06735fb..114e63e 100644 --- a/src/smashed/base/mappers.py +++ b/src/smashed/base/mappers.py @@ -24,17 +24,15 @@ class ChainableMapperMixIn(AbstractBaseMapper): fingerprint: str pipeline: Union["ChainableMapperMixIn", None] - def pip(self, n: Optional[int] = None): + def __getitem__(self, n: int) -> "ChainableMapperMixIn": """Return the n-th mapper in the pipeline, or the next if n is not provided. If n is 0, return this mapper.""" - if n is None: - return self.pipeline - elif n == 0: + if n == 0: return self elif self.pipeline is None: raise IndexError("Pipeline index out of range") else: - return self.pipeline.pip(n - 1) + return self.pipeline[n - 1] def __init__( self, diff --git a/src/smashed/base/pipeline.py b/src/smashed/base/pipeline.py index 1fe8672..826fa35 100644 --- a/src/smashed/base/pipeline.py +++ b/src/smashed/base/pipeline.py @@ -11,5 +11,5 @@ def make_pipeline( ) -> M: """Make a pipeline of mappers.""" for mapper in rest_mappers: - first_mapper = first_mapper.chain(mapper) + first_mapper.chain(mapper) return first_mapper diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 85c79a8..163ae5b 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -1,5 +1,5 @@ import re -from functools import reduce +from functools import cached_property, reduce from typing import ( Any, Dict, @@ -10,10 +10,12 @@ Sequence, Set, Tuple, + Type, Union, cast, ) +from jinja2 import BaseLoader, Environment, Template, meta from necessary import Necessary, necessary from ..base.mappers import ( @@ -22,16 +24,11 @@ SingleBaseMapper, TransformElementType, ) -from ..utils import get_name_and_version with necessary("promptsource", soft=True) as PROMPTSOURCE_AVAILABLE: if PROMPTSOURCE_AVAILABLE: - import yaml - from promptsource.templates import DatasetTemplates, Template - -with necessary("jinja2", soft=True) as JINJA_AVAILABLE: - if JINJA_AVAILABLE: - from jinja2 import Environment, meta + from promptsource.templates import DatasetTemplates + from promptsource.templates import Template as PromptsourceTemplate __all__ = [ @@ -40,21 +37,45 @@ "FewShotJinjaMapper", ] + VARSHOTS = "__shots__" +PIPE_ESCAPE = "3ed2dface8203c4c9dfb1a5dc58e41e0" + + +class JinjaEnvironment: + """A singleton for the jinja environment.""" + + _env: Optional["Environment"] = None + + @classmethod + def env(cls, loader: Optional[Type["BaseLoader"]] = None) -> "Environment": + if cls._env is not None: + return cls._env + + cls._env = Environment( + loader=(loader or BaseLoader) # pyright: ignore + ) + return cls._env + + @classmethod + def from_string( + cls, template: str, env_kwargs: Optional[dict] = None + ) -> "Template": + return cls.env(**(env_kwargs or {})).from_string(template) + + @classmethod + def find_undeclared_variables(cls, template: str) -> Set[str]: + """Find undeclared variables in a jinja template.""" + ast = cls.env().parse(template) + return meta.find_undeclared_variables(ast) -@Necessary( - "promptsource", - message="{module_name} missing. Fix with 'pip install smashed[prompting]'", -) class PromptsourceMixin(ChainableMapperMixIn): def __init__( self, - template: "Template", + template: str, output_source_field_name: str = "source", output_target_field_name: str = "target", - truncate: bool = False, - highlight_variables: bool = False, return_multiple_targets: bool = False, extra_variables: Optional[Dict[str, Any]] = None, ): @@ -63,6 +84,7 @@ def __init__( under the key `source_field_name` and the target sequence is stored under the key `target_field_name`. If the template does not contain the control sequence `|||`, then no target sequence is generated. + Args: template (promptsource.templates.Template): the promptsource template to use. @@ -72,12 +94,6 @@ def __init__( output_target_field_name (str, optional): the name of the field in the returned dictionary of samples that will contain the target sequence. Defaults to "target". - truncate (bool, optional): whether to truncate the source and - target sequences to the maximum length allowed by - the promptsource library. Defaults to False. - highlight_variables (bool, optional): whether to highlight the - variables in the source and target sequences with special - html tags. Defaults to False. return_multiple_targets (bool, optional): whether to return a list of target sequences for each sample. Defaults to False. If the template returns multiple targets, but this argument @@ -86,14 +102,7 @@ def __init__( of extra variables that will be passed to the promptsource template. Defaults to None. """ - self.template = template - self.truncate = truncate - self.highlight_vars = highlight_variables - - # override the id for the template because by default it uses - # a randomly generated uuid which makes hashing impossible - setattr(self.template, "id", 0) self.src_fld_name = output_source_field_name self.tgt_fld_name = output_target_field_name @@ -108,7 +117,7 @@ def __init__( # the output field only contains the target field if the template # has a target portion. output_fields = [self.src_fld_name] - if "|||" in self.template.jinja: + if "|||" in self.template: output_fields.append(self.tgt_fld_name) super().__init__( @@ -117,7 +126,7 @@ def __init__( @staticmethod def get_vars_from_txt(text: str) -> Set[str]: - return meta.find_undeclared_variables(Environment().parse(text)) + return JinjaEnvironment.find_undeclared_variables(text) @property def approx_input_fields(self) -> Tuple[Set[str], ...]: @@ -139,51 +148,51 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]: for field in self.get_vars_from_txt(t) if field not in self.extra_vars ) - for t in self.template.jinja.split("|||") + for t in self.template.split("|||") ) @property def template_text(self) -> Tuple[str, ...]: """The text of the template, with all variables and control sequences removed.""" - return tuple( + fragments = tuple( re.sub(r"\{(%|\{|#).+?(#|%|\})\}", "", t) - for t in self.template.jinja.split("|||") + for t in self.template.split("|||") ) + return fragments @property def has_target(self) -> bool: """Whether the template has one or more target sequence.""" - return "|||" in self.template.jinja + return "|||" in self.template def __getstate__(self) -> dict: - """We need to serialize thve template using yaml so the hash for this - mapper is consistent across runs.""" - out = super().__getstate__() - out["__dict__"]["template"] = yaml.dump(self.template) - return out - - def __setstate__(self, state: dict) -> None: - """Because we serialized the template as yaml, we need to - deserialize before we can use it.""" - super().__setstate__(state) - self.template = yaml.load( - state["__dict__"]["template"], Loader=yaml.FullLoader + """We need to override this method so that the cached property + `_rendered_template` is not pickled. This is because the jinja + environment is not picklable, and the rendered template is + connected to the environment.""" + state = super().__getstate__() + state["__dict__"].pop("_rendered_template", None) + return state + + @cached_property + def _rendered_template(self) -> "Template": + return JinjaEnvironment.from_string( + self.template.replace("|||", PIPE_ESCAPE) ) + def _apply_template(self, data: Dict[str, Any]) -> Sequence[str]: + """Split a string on the pipe escape sequence.""" + content = self._rendered_template.render(data) + return tuple(t.strip() for t in content.split(PIPE_ESCAPE)) + def apply_template(self, data: Dict[str, Any]) -> Sequence[str]: """Given a dictionary of data, apply the template to generate source sequence and target sequence(s).""" - if self.extra_vars: # add any extra variables to the data data = {**data, **self.extra_vars} - - return self.template.apply( - data, - truncate=self.truncate, - highlight_variables=self.highlight_vars, - ) + return self._apply_template(data) def format_output( self, output: Sequence[str] @@ -222,16 +231,20 @@ def transform(self, data: TransformElementType) -> TransformElementType: return self.format_output(encoded) +@Necessary( + "promptsource", + message="{module_name} missing. Fix with 'pip install smashed[prompting]'", +) class PromptsourceMapper(SingleTransformPromptsourceMixin): def __init__( self, dataset_name: str, template_name: str, subset_name: Optional[str] = None, - source_field_name: str = "source", - target_field_name: str = "target", truncate: bool = False, highlight_variables: bool = False, + source_field_name: str = "source", + target_field_name: str = "target", return_multiple_targets: bool = False, extra_variables: Optional[Dict[str, Any]] = None, ): @@ -245,18 +258,18 @@ def __init__( template_name (str): the name of the template to use. subset_name (Optional[str], optional): the name of the subset to use. Defaults to None. - source_field_name (str, optional): the name of the field in the - returned dictionary of samples that will contain the source - sequence. Defaults to "source". - target_field_name (str, optional): the name of the field in the - returned dictionary of samples that will contain the target - sequence. Defaults to "target". truncate (bool, optional): whether to truncate the source and target sequences to the maximum length allowed by the promptsource library. Defaults to False. highlight_variables (bool, optional): whether to highlight the variables in the source and target sequences with special html tags. Defaults to False. + source_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the source + sequence. Defaults to "source". + target_field_name (str, optional): the name of the field in the + returned dictionary of samples that will contain the target + sequence. Defaults to "target". return_multiple_targets (bool, optional): whether to return a list of target sequences for each sample. Defaults to False. If the template returns multiple targets, but this argument @@ -265,38 +278,43 @@ def __init__( of extra variables that will be passed to the promptsource template. Defaults to None. """ - - # DatasetTemplates is not well annotated, so though subset_name - # is optional, it is annotated as `str`, so we need to cast it. - subset_name = cast(str, subset_name) - - template = DatasetTemplates( - dataset_name=dataset_name, - subset_name=subset_name, - )[template_name] + self.truncate = truncate + self.highlight_variables = highlight_variables + self.dataset_name = dataset_name + self.template_name = template_name + self.subset_name = subset_name super().__init__( - template=template, + template=self._rendered_template.jinja, output_source_field_name=source_field_name, output_target_field_name=target_field_name, - truncate=truncate, - highlight_variables=highlight_variables, return_multiple_targets=return_multiple_targets, extra_variables=extra_variables, ) + @cached_property + def _rendered_template(self) -> "PromptsourceTemplate": + # the type: ignore is because the promptsource library is not + # very well typed, so, even though subset_name should + return DatasetTemplates( + dataset_name=self.dataset_name, + subset_name=cast(str, self.subset_name), + )[self.template_name] + + def _apply_template(self, data: Dict[str, Any]) -> Sequence[str]: + return self._rendered_template.apply( + example=data, + truncate=self.truncate, + highlight_variables=self.highlight_variables, + ) + class JinjaMapper(SingleTransformPromptsourceMixin): def __init__( self, jinja: str, - name: Optional[str] = None, - reference: Optional[str] = None, - metadata: Optional["Template.Metadata"] = None, source_field_name: str = "source", target_field_name: str = "target", - truncate: bool = False, - highlight_variables: bool = False, return_multiple_targets: bool = False, extra_variables: Optional[Dict[str, Any]] = None, ): @@ -312,20 +330,12 @@ def __init__( to None. reference (Optional[str], optional): the reference for the template. Defaults to None. - metadata (Optional["Template.Metadata"], optional): the metadata - for the template. Defaults to None. source_field_name (str, optional): the name of the field in the returned dictionary of samples that will contain the source sequence. Defaults to "source". target_field_name (str, optional): the name of the field in the returned dictionary of samples that will contain the target sequence. Defaults to "target". - truncate (bool, optional): whether to truncate the source and - target sequences to the maximum length allowed by - the promptsource library. Defaults to False. - highlight_variables (bool, optional): whether to highlight the - variables in the source and target sequences with special - html tags. Defaults to False. return_multiple_targets (bool, optional): whether to return a list of target sequences for each sample. Defaults to False. If the template returns multiple targets, but this argument @@ -334,18 +344,10 @@ def __init__( of extra variables that will be passed to the promptsource template. Defaults to None. """ - template = Template( - jinja=jinja, - name=name, - reference=(reference or get_name_and_version()), - metadata=metadata, - ) super().__init__( - template=template, + template=jinja, output_source_field_name=source_field_name, output_target_field_name=target_field_name, - truncate=truncate, - highlight_variables=highlight_variables, return_multiple_targets=return_multiple_targets, extra_variables=extra_variables, ) @@ -356,9 +358,6 @@ def __init__( self, jinja: str, num_shots: Union[int, Literal["max"]], - name: Optional[str] = None, - reference: Optional[str] = None, - metadata: Optional["Template.Metadata"] = None, keep_last: bool = False, output_source_field_name: str = "source", output_target_field_name: str = "target", @@ -424,13 +423,6 @@ def __init__( f"the jinja template must contain the variable {VARSHOTS}" ) - template = Template( - jinja=jinja, - name=name, - reference=(reference or get_name_and_version()), - metadata=metadata, - ) - # mypy complains if we don't retype num_shots self.num_shots: Union[int, Literal["max"]] = num_shots @@ -439,11 +431,9 @@ def __init__( self.keep_last: bool = keep_last or num_shots == "max" super().__init__( - template=template, + template=jinja, output_source_field_name=output_source_field_name, output_target_field_name=output_target_field_name, - truncate=truncate, - highlight_variables=highlight_variables, return_multiple_targets=return_multiple_targets, extra_variables=extra_variables, ) diff --git a/src/smashed/mappers/text.py b/src/smashed/mappers/text.py index a4a0be3..d5de999 100644 --- a/src/smashed/mappers/text.py +++ b/src/smashed/mappers/text.py @@ -71,7 +71,7 @@ def __init__( fields: Union[str, Sequence[str]], splitter: Literal[ "blingfire", "whitespace", "whitespace_plus" - ] = "whitespace", + ] = "whitespace_plus", ): if splitter == "blingfire": self.splitter = BlingFireSplitter() diff --git a/src/smashed/recipes/promptsource.py b/src/smashed/recipes/promptsource.py index f2f453c..e44bfe0 100644 --- a/src/smashed/recipes/promptsource.py +++ b/src/smashed/recipes/promptsource.py @@ -151,10 +151,12 @@ def ceil(x): # (e.g., instructions) must be divided over n + 1 sources. actual_source_context_length = ( max_source_length_per_shot + # this is (a) from above - ceil( (max_target_length_per_shot or 0) * (num_shots / (num_shots + 1)) ) + # this is (b) from above - ceil(length_src_prompt / (num_shots + 1)) ) diff --git a/scripts/blingfire-osx-arm64.sh b/src/smashed/utils/install_blingfire_macos.py similarity index 55% rename from scripts/blingfire-osx-arm64.sh rename to src/smashed/utils/install_blingfire_macos.py index fe99564..3ef658f 100644 --- a/scripts/blingfire-osx-arm64.sh +++ b/src/smashed/utils/install_blingfire_macos.py @@ -1,17 +1,16 @@ +#! /usr/bin/env python3 + +from subprocess import call + +BASH_SCRIPT = ''' #! /usr/bin/env bash -# get script directory -SOURCE="${BASH_SOURCE[0]}" -while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink - SCRIPT_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" - SOURCE="$(readlink "$SOURCE")" - # if $SOURCE was a relative symlink, we need to resolve it - # relative to the path where the symlink file was located - [[ $SOURCE != /* ]] && SOURCE="$SCRIPT_DIR/$SOURCE" -done -SCRIPT_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" +# path to the current directory CURRENT_DIR="$(pwd)" +# remove any existing blingfire installation +pip uninstall -y blingfire 2>/dev/null + # clone blingfire repo to a temp directory TMP_DIR=$(mktemp -d) cd $TMP_DIR @@ -36,3 +35,12 @@ # cleanup cd $CURRENT_DIR rm -rf $TMP_DIR +''' + + +def main(): + call(BASH_SCRIPT.strip(), shell=True) + + +if __name__ == "__main__": + main() diff --git a/src/smashed/utils/wordsplitter.py b/src/smashed/utils/wordsplitter.py index c8e0df4..e236c98 100644 --- a/src/smashed/utils/wordsplitter.py +++ b/src/smashed/utils/wordsplitter.py @@ -29,7 +29,11 @@ def __call__( @Necessary( "blingfire", - message="{module_name} missing. Fix with 'pip install smashed[prompting]'", + message=( + "{module_name} missing. Fix with 'pip install smashed[prompting]'" + "or, if you are on a Mac with Apple Silicon chip, " + "'python -m smashed.utils.install_blingfire_macos'." + ) ) class BlingFireSplitter(BaseWordSplitter): def tokenize(self, text: str) -> List[str]: diff --git a/tests/test_promptsource_recipe.py b/tests/test_promptsource_recipe.py index 93e9636..a10a594 100644 --- a/tests/test_promptsource_recipe.py +++ b/tests/test_promptsource_recipe.py @@ -52,12 +52,15 @@ def test_promptsource_recipe(self): recipe = JinjaRecipe( tokenizer=self.tokenizer, jinja_template="Q: {{question}}\nC: {{context}}\nA: |||{{answer}}", - max_source_length_per_shot=15, + # this used to be 15, but now using 'whitespace_plus' tokenizer + # as default, which means that we have a few more tokens in the + # non-variables part of the prompt. + max_source_length_per_shot=18, max_target_length_per_shot=5, ) dataset = [ { - "question": "What is the capital of France?", + "question": "What is the capital of France", "context": "Paris is the capital of " + ("France " * 10), "answer": "Paris " * 10, } @@ -68,7 +71,7 @@ def test_promptsource_recipe(self): self.assertEqual( self.tokenizer.decode(mapped_dataset["input_ids"]), ( - "Q: What is the capital of France? " + "Q: What is the capital of France " "C: Paris is the capital of France " "A:" ),