From bdbc617aa84306419c783d7e188fbe35aea58cc3 Mon Sep 17 00:00:00 2001
From: Luca Soldaini
Date: Thu, 5 Jan 2023 16:25:56 -0800
Subject: [PATCH] Soldni/stride fix (#39)

* fix how the last stride is computed
* fixed small issue with collator in case ignore does not exist
* improvements to collator
* left padding support
* rename support
* pretty
* tests for tokenizer
* fixed issues with datasets 2.8.0 refactor
* necessary library syntax fix
---
 pyproject.toml                   |  2 +-
 src/smashed/base/interfaces.py   | 18 +++++--
 src/smashed/mappers/cache.py     | 20 +++++--
 src/smashed/mappers/collators.py | 46 ++++++++++++++--
 src/smashed/mappers/fields.py    | 19 +++++--
 src/smashed/mappers/glom.py      | 16 ++++--
 src/smashed/mappers/multiseq.py  | 19 ++++++-
 src/smashed/mappers/prompting.py | 90 +++++++++++++++++++++++++-------
 src/smashed/mappers/tokenize.py  | 61 +++++++++++++++++-----
 src/smashed/recipes/collators.py |  6 ++-
 tests/test_batch_interface.py    | 10 +++-
 tests/test_collators.py          | 41 +++++++++++++++
 tests/test_hf_pickling.py        |  2 +-
 tests/test_tokenize_mappers.py   | 50 ++++++++++++++++++
 14 files changed, 334 insertions(+), 66 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 28f529d..89fa877 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "smashed"
-version = "0.12.0"
+version = "0.13.0"
 description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
 authors = [
     {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" },
diff --git a/src/smashed/base/interfaces.py b/src/smashed/base/interfaces.py
index db7295e..d8ace4d 100644
--- a/src/smashed/base/interfaces.py
+++ b/src/smashed/base/interfaces.py
@@ -28,7 +28,15 @@
 with necessary("datasets", soft=True) as HUGGINGFACE_DATASET_AVAILABLE:
     if HUGGINGFACE_DATASET_AVAILABLE or TYPE_CHECKING:
-        from datasets.arrow_dataset import Batch, Dataset
+        from datasets.arrow_dataset import Dataset
+
+        try:
+            from datasets.formatting.formatting import LazyBatch
+        except ImportError:
+            # pre datasets 2.8.0
+            from datasets.arrow_dataset import (
+                Batch as LazyBatch,  # pyright: ignore
+            )
         from datasets.iterable_dataset import IterableDataset

         HuggingFaceDataset = TypeVar(
@@ -284,12 +292,12 @@ def _map_huggingface_dataset(
         else:
             return transformed_dataset

-    @map.add_interface(dataset=Batch)
+    @map.add_interface(dataset=LazyBatch)
     def _map_huggingface_dataset_batch(
         self,
-        dataset: Batch,
+        dataset: LazyBatch,
         **map_kwargs: Any,
-    ) -> Batch:
+    ) -> LazyBatch:
         # explicitly casting to a boolean since this is all that is
         # supported by the simple mapper.
         # TODO[lucas]: maybe support specifying which fields to keep?
@@ -298,7 +306,7 @@ def _map_huggingface_dataset_batch( or self.always_remove_columns ) - dtview: DataBatchView[Batch, str, Any] = DataBatchView(dataset) + dtview: DataBatchView[LazyBatch, str, Any] = DataBatchView(dataset) self._check_fields_datasets( provided_fields=dataset.keys(), diff --git a/src/smashed/mappers/cache.py b/src/smashed/mappers/cache.py index f10d71b..8f1d349 100644 --- a/src/smashed/mappers/cache.py +++ b/src/smashed/mappers/cache.py @@ -13,7 +13,15 @@ with necessary("datasets", soft=True) as HUGGINGFACE_DATASET_AVAILABLE: if HUGGINGFACE_DATASET_AVAILABLE or TYPE_CHECKING: - from datasets.arrow_dataset import Dataset, Batch + from datasets.arrow_dataset import Dataset + + try: + from datasets.formatting.formatting import LazyBatch + except ImportError: + # pre datasets 2.8.0 + from datasets.arrow_dataset import ( + Batch as LazyBatch, # pyright: ignore + ) from datasets.iterable_dataset import IterableDataset from datasets.fingerprint import disable_caching, enable_caching @@ -130,8 +138,8 @@ def get_dataset_fingerprint_hf_iterable( ) return h.hexdigest() - @get_dataset_fingerprint.add_interface(dataset=Batch) - def get_dataset_fingerprint_hf_batch(self, dataset: Batch) -> str: + @get_dataset_fingerprint.add_interface(dataset=LazyBatch) + def get_dataset_fingerprint_hf_batch(self, dataset: LazyBatch) -> str: raise ValueError( "Cannot cache a Batch of a HuggingFace Dataset; please " "cache at the Dataset level instead." @@ -198,7 +206,7 @@ def _save_hf_it(self, dataset: IterableDataset, path: Path): "Saving an IterableDataset is not implemented yet" ) - @save_cache.add_interface(dataset=Batch) + @save_cache.add_interface(dataset=LazyBatch) def _save_hf_batch(self, dataset: Dataset, path: Path): raise ValueError( "Cannot cache a Batch of a HuggingFace Dataset; please " @@ -274,7 +282,9 @@ def _load_list( if HUGGINGFACE_DATASET_AVAILABLE: - @load_cache.add_interface(dataset=(IterableDataset, Dataset, Batch)) + @load_cache.add_interface( + dataset=(IterableDataset, Dataset, LazyBatch) + ) def _load_hf( self, path: Path, diff --git a/src/smashed/mappers/collators.py b/src/smashed/mappers/collators.py index 7c78373..17aa752 100644 --- a/src/smashed/mappers/collators.py +++ b/src/smashed/mappers/collators.py @@ -23,6 +23,7 @@ def __init__( pad_to_length: Optional[Union[int, Sequence[int]]] = None, fields_pad_ids: Optional[Mapping[str, int]] = None, unk_fields_pad_id: Optional[int] = None, + left_pad_fields: Optional[Sequence[str]] = None, ): """Create a collator. @@ -41,10 +42,14 @@ def __init__( unk_fields_pad_id (int, optional): The padding value to use for any field that is not in fields_pad_ids. If not provided, an error will be raised if a field is not in fields_pad_ids. + left_pad_fields (Sequence[str], optional): A list of fields to + pad from the left instead of the right. By default, all fields + are padded from the right. """ self.fields_pad_ids = fields_pad_ids or {} self.pad_to_length = pad_to_length self.unk_fields_pad_id = unk_fields_pad_id + self.left_pad_fields = set(left_pad_fields or []) if self.unk_fields_pad_id is None and self.fields_pad_ids is None: raise ValueError( @@ -145,7 +150,25 @@ def _pad( pad_value: int, dim: int = 0, pad_to_length: Optional[Union[int, Sequence[int]]] = None, + right_pad: bool = True, ) -> torch.Tensor: + """Pad a sequence of tensors to the same length. + + Args: + sequence (Sequence[torch.Tensor]): The sequence of tensors to pad. 
+                It is assumed that all tensors in the sequence have the same
+                type; if not, an error may be raised.
+            pad_value (int): The value to use for padding.
+            dim (int, optional): The dimension we are collating on. Defaults
+                to 0.
+            pad_to_length (Union[int, Sequence[int]], optional): If provided,
+                pad all sequences to this length. If provided as a sequence,
+                we assume we should pad each dimension to the corresponding
+                length. If None, sequences will be padded to the length of the
+                longest sequence. Defaults to None.
+            right_pad (bool, optional): If True, pad to the right. If False,
+                pad to the left. Defaults to True.
+        """

         # make sure type of input is right
         if not (
@@ -192,13 +215,14 @@ def _pad(
         pad_shapes = tuple(
             tuple(
                 chain.from_iterable(
-                    (0, m - s)
+                    (0, m - s) if right_pad else (m - s, 0)
                     for s, m in zip(t.size()[::-1], max_lengths[::-1])
                 )
             )
             # we do padding shapes for each tensor
             for t in sequence
         )
+
         # call each pad on each of the tensors with the appropriate padding
         to_stack = tuple(
             torch.nn.functional.pad(
@@ -218,6 +242,7 @@ def transform(  # type: ignore
                 sequence=list_of_tensors,
                 pad_value=self._get_padding_value(field_name=field_name),
                 pad_to_length=self.pad_to_length,
+                right_pad=(field_name not in self.left_pad_fields),
             )
             for field_name, list_of_tensors in data.items()
         }
@@ -270,13 +295,17 @@ def _get_list_shape_recursive(

         # this iterator will yield the shape of each element in the sequence
         inner_dims = (self._get_list_shape_recursive(s) for s in sequence)

-        # the acutal shape is the maximum of the inner dims
+        # the actual shape is the maximum of the inner dims
         inner_shape = tuple(max(dims) for dims in zip(*inner_dims))

         return (len(sequence), *inner_shape)

     def _pad_recursive(
-        self, sequence: List[Any], shape: Sequence[int], padding_symbol: Any
+        self,
+        sequence: List[Any],
+        shape: Sequence[int],
+        padding_symbol: Any,
+        pad_right: bool = True,
     ) -> List[Any]:
         """Recursively pads a list of [lists, ...].

         Args:
             sequence (List[Any]): The list to pad.
             shape (Sequence[int]): The shape to pad to.
             padding_symbol (Any): The symbol to pad with.
+            pad_right (bool, optional): If True, pads to the right. If False,
+                pads to the left. Defaults to True.

         Returns:
-            List[Any]: _description_
+            List[Any]: The padded list.
""" if len(shape) < 2: @@ -321,7 +352,11 @@ def _pad_recursive( # # We do that in the following line: sequence_with_brand_new_padding = ( + # the side we pad depends on wether pad_right is True or False sub_seq + [nested_pad_symbol] * (dim_to_pad_shape - len(sub_seq)) + if pad_right + else [nested_pad_symbol] * (dim_to_pad_shape - len(sub_seq)) + + sub_seq for sub_seq in sequence ) @@ -342,6 +377,7 @@ def _pad( self: "ListCollatorMapper", seq_of_seq_to_pad: List[Any], padding_symbol: Any, + pad_right: bool = True, ) -> List[Any]: padding_shape = self._get_list_shape_recursive(seq_of_seq_to_pad) @@ -367,6 +403,7 @@ def _pad( sequence=seq_of_seq_to_pad, shape=padding_shape, padding_symbol=padding_symbol, + pad_right=pad_right, ) return padded_sequence @@ -377,6 +414,7 @@ def transform(self, data: TransformElementType) -> TransformElementType: field_name: self._pad( seq_of_seq_to_pad=field_value, padding_symbol=self._get_padding_value(field_name=field_name), + pad_right=(field_name not in self.left_pad_fields), ) for field_name, field_value in data.items() } diff --git a/src/smashed/mappers/fields.py b/src/smashed/mappers/fields.py index f55c70c..d680505 100644 --- a/src/smashed/mappers/fields.py +++ b/src/smashed/mappers/fields.py @@ -27,6 +27,7 @@ def __init__( self, keep_fields: Optional[List[str]] = None, drop_fields: Optional[List[str]] = None, + raise_on_missing: bool = True, ): """ Args: @@ -34,6 +35,8 @@ def __init__( are dropped. Defaults to []. drop_fields (List[str]): Fields to drop, all other fields are kept. Defaults to []. + raise_on_missing (bool): Whether to raise an error if a field + is missing. Defaults to True. """ # xor between keep_fields and remove_fields @@ -42,16 +45,22 @@ def __init__( ): raise ValueError("Must specify `keep_fields` or `drop_fields`") - super().__init__(input_fields=drop_fields, output_fields=keep_fields) + self.keep_fields = dict.fromkeys(keep_fields) if keep_fields else None + self.drop_fields = dict.fromkeys(drop_fields) if drop_fields else None + + super().__init__( + input_fields=drop_fields if raise_on_missing else None, + output_fields=keep_fields if raise_on_missing else None, + ) def transform(self, data: TransformElementType) -> TransformElementType: - if self.input_fields: + if self.drop_fields: new_data = { - k: v for k, v in data.items() if k not in self.input_fields + k: v for k, v in data.items() if k not in self.drop_fields } - elif self.output_fields: - new_data = {k: data[k] for k in self.output_fields} + elif self.keep_fields: + new_data = {k: data[k] for k in data if k in self.keep_fields} else: raise ValueError("Must specify `keep_fields` or `drop_fields`") diff --git a/src/smashed/mappers/glom.py b/src/smashed/mappers/glom.py index 76991d7..1057b7e 100644 --- a/src/smashed/mappers/glom.py +++ b/src/smashed/mappers/glom.py @@ -8,7 +8,13 @@ with necessary("datasets", soft=True) as DATASETS_AVAILABLE: if DATASETS_AVAILABLE: - from datasets.arrow_dataset import Example + try: + from datasets.formatting.formatting import LazyRow + except ImportError: + # pre datasets 2.8.0 + from datasets.arrow_dataset import ( + Example as LazyRow, # pyright: ignore + ) class ExtendGlommerMixin: @@ -26,10 +32,10 @@ def glommer(self) -> glom.Glommer: if DATASETS_AVAILABLE: glommer.register( - target_type=Example, - get=Example.__getitem__, - iter=Example.__iter__, - exact=Example.__eq__, + target_type=LazyRow, + get=LazyRow.__getitem__, + iter=LazyRow.__iter__, + exact=LazyRow.__eq__, ) glommer.register( diff --git a/src/smashed/mappers/multiseq.py 
index 47230e3..e6b7213 100644
--- a/src/smashed/mappers/multiseq.py
+++ b/src/smashed/mappers/multiseq.py
@@ -396,6 +396,7 @@ def transform(
         ) >= self.max_stride_count

         if stride_too_long or stride_has_too_many_seqs:
+
             yield {
                 k: (
                     # if a list of fields to strides has been provided,
@@ -423,7 +424,21 @@ def transform(
             cumulative_stride_length += current_seq_length

         # yield the last sequence
-        out = {k: v[seq_pos_start:] for k, v in sample.items()}
+        out = {
+            k: (
+                # same logic as above: if a list of fields to stride
+                # has been provided, then only stride this field if it
+                # is in the list and duplicate if it is not; if no list
+                # of fields to stride has been provided, then stride all.
+                v[seq_pos_start:]
+                if (
+                    self.fields_to_stride is None
+                    or k in self.fields_to_stride
+                )
+                else v
+            )
+            for k, v in sample.items()
+        }

         yield out

@@ -434,7 +449,7 @@ def __init__(
         single_value_field: str,
         like_field: str = "input_ids",
         strategy: Literal["first", "last", "all"] = "first",
-        padding_id: Union[int, float] = -100,
+        padding_id: Any = -100,
     ) -> None:
         """Mapper to create a sequence of values from single value.
         Useful when casting a sequence classification task to a sequence
diff --git a/src/smashed/mappers/prompting.py b/src/smashed/mappers/prompting.py
index 67ac0ee..3987bbc 100644
--- a/src/smashed/mappers/prompting.py
+++ b/src/smashed/mappers/prompting.py
@@ -8,7 +8,7 @@
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

 from ..base import SingleBaseMapper, TransformElementType
-from .tokenize import GetTokenizerOutputFieldsMixin
+from .tokenize import GetTokenizerOutputFieldsAndNamesMixIn

 __all__ = [
     "EncodeFieldsMapper",
@@ -401,42 +401,92 @@ def transform(self, data: TransformElementType) -> TransformElementType:
         return data


-class FillEncodedPromptMapper(SingleBaseMapper, GetTokenizerOutputFieldsMixin):
+class FillEncodedPromptMapper(
+    SingleBaseMapper, GetTokenizerOutputFieldsAndNamesMixIn
+):
+    """Fills a prompt template with already encoded (i.e., tokenized and
+    turned into ids) data."""
+
     def __init__(
         self,
         template: str,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         output_prefix: Optional[str] = None,
+        output_rename_map: Optional[Dict[str, str]] = None,
         return_attention_mask: bool = True,
         return_token_type_ids: bool = False,
         add_bos_token: bool = True,
         add_eos_token: bool = True,
     ) -> None:
-        self.tokenizer = tokenizer
-        self._prefix = output_prefix
+        """
+        Args:
+            template (str): The template to fill. It should be a string with
+                placeholders for the fields to fill.
+            tokenizer (Optional[PreTrainedTokenizerBase]): The tokenizer used
+                to encode the prompt. It is used to look up the bos and eos
+                tokens. If add_bos_token or add_eos_token are False, this
+                can be None. Defaults to None.
+            output_prefix (Optional[str]): A prefix to add to all output
+                fields. Defaults to None. An error will be raised if both
+                output_prefix and output_rename_map are provided.
+            output_rename_map (Optional[Dict[str, str]]): A map that specifies
+                how fields in the tokenizer output should be renamed. If None,
+                the fields will not be renamed. Defaults to None. An error
+                will be raised if both output_prefix and output_rename_map
+                are provided.
+            return_attention_mask (bool): Whether to return the attention mask.
+                Defaults to True.
+            return_token_type_ids (bool): Whether to return the token type ids.
+                Defaults to False.
+            add_bos_token (bool): Whether to add the bos token to the prompt.
+ Defaults to True. + add_eos_token (bool): Whether to add the eos token to the prompt. + Defaults to True. + """ + GetTokenizerOutputFieldsAndNamesMixIn.__init__( + self, + output_rename_map=output_rename_map, + output_prefix=output_prefix, + ) self.return_attention_mask = return_attention_mask self.return_token_type_ids = return_token_type_ids - self.bos_token_ids = ( - [] - if tokenizer.bos_token_id is None or not add_bos_token - else [tokenizer.bos_token_id] - ) - self.eos_token_ids = ( - [] - if tokenizer.eos_token_id is None or not add_eos_token - else [tokenizer.eos_token_id] - ) + if add_bos_token: + if tokenizer is None: + raise ValueError( + "Cannot add bos token if no tokenizer is provided." + ) + self.bos_token_ids = ( + [tokenizer.bos_token_id] + if tokenizer.bos_token_id is not None + else [] + ) + else: + self.bos_token_ids = [] + + if add_eos_token: + if tokenizer is None: + raise ValueError( + "Cannot add eos token if no tokenizer is provided." + ) + self.eos_token_ids = ( + [tokenizer.eos_token_id] + if tokenizer.eos_token_id is not None + else [] + ) + else: + self.eos_token_ids = [] self.prompt = PromptSegment.from_template( template=template, tokenizer=tokenizer ) - super().__init__( + SingleBaseMapper.__init__( + self, input_fields=[p.field_name for p in self.prompt if p.field_name], output_fields=[ - self.prefix(field_name) + self.fname(field_name) for field_name in self.output_fields_from_tokenizer_kwargs( tokenizer_kwargs={ "return_attention_mask": return_attention_mask, @@ -453,10 +503,10 @@ def transform(self, data: TransformElementType) -> TransformElementType: + self.eos_token_ids ) - output = {self.prefix("input_ids"): encoded_prompt} + output = {self.fname("input_ids"): encoded_prompt} if self.return_attention_mask: - output[self.prefix("attention_mask")] = [1] * len(encoded_prompt) + output[self.fname("attention_mask")] = [1] * len(encoded_prompt) if self.return_token_type_ids: - output[self.prefix("token_type_ids")] = [0] * len(encoded_prompt) + output[self.fname("token_type_ids")] = [0] * len(encoded_prompt) return output diff --git a/src/smashed/mappers/tokenize.py b/src/smashed/mappers/tokenize.py index 5f5f514..880eda4 100644 --- a/src/smashed/mappers/tokenize.py +++ b/src/smashed/mappers/tokenize.py @@ -19,15 +19,28 @@ ] -class GetTokenizerOutputFieldsMixin: +class GetTokenizerOutputFieldsAndNamesMixIn: """A mixin class that figures out the output fields based on the arguments that will be passed a to tokenizer.__call__ method.""" tokenizer: PreTrainedTokenizerBase _prefix: Optional[str] + def __init__( + self, + output_prefix: Optional[str] = None, + output_rename_map: Optional[Dict[str, str]] = None, + ): + assert ( + output_prefix is None or output_rename_map is None + ), "You cannot specify both output_prefix and output_rename_map." 
+
+        self._output_prefix = output_prefix
+        self._output_rename_map = output_rename_map
+
+    @staticmethod
     def output_fields_from_tokenizer_kwargs(
-        self, tokenizer_kwargs: Optional[dict] = None
+        tokenizer_kwargs: Optional[dict] = None,
     ) -> List[str]:
         tokenizer_kwargs = tokenizer_kwargs or {}
@@ -49,15 +62,21 @@ def output_fields_from_tokenizer_kwargs(

         return output_fields

-    def prefix(self, field_or_dict: str) -> str:
-        return (
-            f"{self._prefix}_{field_or_dict}"
-            if self._prefix
-            else field_or_dict
-        )
+    def fname(self, field_or_dict: str) -> str:
+        if self._output_prefix:
+            return f"{self._output_prefix}_{field_or_dict}"
+        elif self._output_rename_map:
+            if field_or_dict in self._output_rename_map:
+                return self._output_rename_map[field_or_dict]
+            else:
+                raise ValueError(
+                    f"Field '{field_or_dict}' is not in the rename map."
+                )
+        else:
+            return field_or_dict


-class TokenizerMapper(SingleBaseMapper, GetTokenizerOutputFieldsMixin):
+class TokenizerMapper(SingleBaseMapper, GetTokenizerOutputFieldsAndNamesMixIn):
     """Tokenize a field using a tokenizer."""

     def __init__(
@@ -65,6 +84,7 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
         input_field: str,
         output_prefix: Optional[str] = None,
+        output_rename_map: Optional[Dict[str, str]] = None,
         add_special_tokens: Optional[bool] = True,
         max_length: Optional[int] = None,
         is_split_into_words: Optional[bool] = False,
@@ -84,7 +104,13 @@ def __init__(
             huggingface/transformers library.
             input_field (str): The field to tokenize.
             output_prefix (Optional[str], optional): A prefix to add to all
-                output fields. Defaults to None.
+                output fields. Defaults to None. An error will be raised if
+                both output_prefix and output_rename_map are provided.
+            output_rename_map (Optional[Dict[str, str]], optional): A map
+                that specifies how fields in the tokenizer output should be
+                renamed. If None, the fields will not be renamed. Defaults
+                to None. An error will be raised if both output_prefix and
+                output_rename_map are provided.
             add_special_tokens (Optional[bool], optional): Whether or not
                 to add special tokens to the input. Defaults to True.
             max_length (Optional[int], optional): The maximum length of the
@@ -112,9 +138,15 @@ def __init__(
                 to the tokenizer; these will override the above arguments.
""" + # this deal with names of fields to expect + GetTokenizerOutputFieldsAndNamesMixIn.__init__( + self, + output_prefix=output_prefix, + output_rename_map=output_rename_map, + ) + self.to_tokenize_filed = input_field self.tokenizer = tokenizer - self._prefix = output_prefix # arguments to be passed to the tokenizer __call__ function go here tokenizer_kwargs = { @@ -147,9 +179,10 @@ def __init__( self.tokenize_kwargs = tokenizer_kwargs - super().__init__( + SingleBaseMapper.__init__( + self, input_fields=[self.to_tokenize_filed], - output_fields=list(map(self.prefix, output_fields)), + output_fields=list(map(self.fname, output_fields)), ) def transform(self, data: TransformElementType) -> TransformElementType: @@ -186,7 +219,7 @@ def transform(self, data: TransformElementType) -> TransformElementType: ] return { - self.prefix(field_name): field_value + self.fname(field_name): field_value for field_name, field_value in batch_encoding.items() } diff --git a/src/smashed/recipes/collators.py b/src/smashed/recipes/collators.py index 3e66721..d856942 100644 --- a/src/smashed/recipes/collators.py +++ b/src/smashed/recipes/collators.py @@ -25,7 +25,7 @@ def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]: # skip fields that do not support collation as tensors; we will # reinsert them later as lists in each batch skipped: Dict[str, List[Any]] = { - field: [sample.pop(field) for sample in batch] + field: [sample.pop(field) for sample in batch if field in sample] for field in self.do_not_collate } @@ -40,7 +40,9 @@ def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]: collated_batch: Dict[str, List[Any]] = out[0] # here we reattach the answers to the batch - collated_batch.update(skipped) + # "if v" prevents us from adding empty lists, + # which correspond to fields that were not present in the batch + collated_batch.update({k: v for k, v in skipped.items() if v}) return collated_batch diff --git a/tests/test_batch_interface.py b/tests/test_batch_interface.py index 74aca08..ed9149f 100644 --- a/tests/test_batch_interface.py +++ b/tests/test_batch_interface.py @@ -2,7 +2,13 @@ from copy import deepcopy from functools import partial -from datasets.arrow_dataset import Batch, Dataset +from datasets.arrow_dataset import Dataset + +try: + from datasets.formatting.formatting import LazyBatch +except ImportError: + # pre datasets 2.8.0 + from datasets.arrow_dataset import Batch as LazyBatch # pyright: ignore from smashed.mappers.debug import MockMapper @@ -13,7 +19,7 @@ def test_batch(self, remove_columns: bool = False): data = Dataset.from_list([{"a": i, "b": i ** 2} for i in range(100)]) - def _batch_fn(data: Batch, mapper: MockMapper) -> Batch: + def _batch_fn(data: LazyBatch, mapper: MockMapper) -> LazyBatch: return mapper.map(deepcopy(data), remove_columns=remove_columns) fn = partial(_batch_fn, mapper=mapper) diff --git a/tests/test_collators.py b/tests/test_collators.py index c3b80bc..7cc4707 100644 --- a/tests/test_collators.py +++ b/tests/test_collators.py @@ -84,6 +84,26 @@ def test_nested_collators(self): self.assertEqual(grouped_a[1][1], [5.0, 5.1, -1, -1, -1]) self.assertEqual(grouped_a[1][2], [-1, -1, -1, -1, -1]) + def test_left_padding(self): + dataset = [ + {"a": [1, 2, 3]}, + {"a": [4, 5]}, + {"a": [6, 7, 8, 9, 10]}, + ] + pipeline = FixedBatchSizeMapper( + batch_size="max" + ) >> ListCollatorMapper( + fields_pad_ids={"a": -1}, left_pad_fields=["a"] + ) + + output = pipeline.map(dataset) + + self.assertEqual(len(output[0]["a"]), 3) + 
+        self.assertEqual(output[0]["a"][0], [-1, -1, 1, 2, 3])
+        self.assertEqual(output[0]["a"][1], [-1, -1, -1, 4, 5])
+        self.assertEqual(output[0]["a"][2], [6, 7, 8, 9, 10])
+

 class TestTensorCollators(unittest.TestCase):
     def test_base_collator(self):
@@ -134,3 +154,24 @@ def test_from_tokenizer_collator(self):

         # same thing except attention mask uses 0 for padding
         self.assertEqual((collated_dataset[0]["attention_mask"] == 0).sum(), 4)
+
+    def test_left_padding(self):
+        dataset = [
+            {"a": [1, 2, 3]},
+            {"a": [4, 5]},
+            {"a": [6, 7, 8, 9, 10]},
+        ]
+        pipeline = (
+            Python2TorchMapper()
+            >> FixedBatchSizeMapper(batch_size="max")
+            >> TensorCollatorMapper(
+                fields_pad_ids={"a": -1}, left_pad_fields=["a"]
+            )
+        )
+
+        output = pipeline.map(dataset)
+
+        self.assertEqual(output[0]["a"].shape, (3, 5))
+        self.assertEqual(output[0]["a"][0].tolist(), [-1, -1, 1, 2, 3])
+        self.assertEqual(output[0]["a"][1].tolist(), [-1, -1, -1, 4, 5])
+        self.assertEqual(output[0]["a"][2].tolist(), [6, 7, 8, 9, 10])
diff --git a/tests/test_hf_pickling.py b/tests/test_hf_pickling.py
index d3257f8..d261a13 100644
--- a/tests/test_hf_pickling.py
+++ b/tests/test_hf_pickling.py
@@ -14,7 +14,7 @@
 )
 from smashed.mappers.debug import MockMapper

-with necessary(("datasets", "dill")):
+with necessary(["datasets", "dill"]):
     import dill
     from datasets.arrow_dataset import Dataset
     from datasets.fingerprint import Hasher
diff --git a/tests/test_tokenize_mappers.py b/tests/test_tokenize_mappers.py
index e94c10d..4f4e5d6 100644
--- a/tests/test_tokenize_mappers.py
+++ b/tests/test_tokenize_mappers.py
@@ -450,3 +450,53 @@ def test_return_words(self):
             None,
         ],
     )
+
+    def test_prefix(self):
+        mapper = TokenizerMapper(
+            input_field="text",
+            tokenizer=self.tokenizer,
+            return_attention_mask=False,
+            output_prefix="test",
+        )
+
+        dataset = [
+            {"text": "This is a sentence."},
+        ]
+
+        new_dataset = mapper.map(dataset)
+        self.assertEqual("test_input_ids" in new_dataset[0], True)
+        self.assertEqual(
+            new_dataset[0]["test_input_ids"],
+            [102, 238, 165, 106, 8517, 205, 103],
+        )
+
+    def test_rename(self):
+        mapper = TokenizerMapper(
+            input_field="text",
+            tokenizer=self.tokenizer,
+            return_attention_mask=True,
+            output_rename_map={"input_ids": "foo", "attention_mask": "bar"},
+        )
+
+        dataset = [
+            {"text": "This is a sentence."},
+        ]
+
+        new_dataset = mapper.map(dataset)
+        self.assertTrue("foo" in new_dataset[0])
+        self.assertTrue("bar" in new_dataset[0])
+        self.assertFalse("input_ids" in new_dataset[0])
+        self.assertFalse("attention_mask" in new_dataset[0])
+
+        with self.assertRaises(ValueError):
+            mapper = TokenizerMapper(
+                input_field="text",
+                tokenizer=self.tokenizer,
+                return_attention_mask=True,
+                return_token_type_ids=True,
+                output_rename_map={
+                    "input_ids": "foo",
+                    "attention_mask": "bar",
+                },
+            )
+            mapper.map(dataset)
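
Usage sketches (appended notes, not part of the diff above).

First, the new left_pad_fields option on the collators. This sketch assumes
the same imports used in tests/test_collators.py, and combines left padding
with the pre-existing pad_to_length argument; the tensor shown in the comment
is the expected result inferred from the _pad docstring, not output copied
from a run:

    from smashed.mappers import (
        FixedBatchSizeMapper,
        Python2TorchMapper,
        TensorCollatorMapper,
    )

    # collate "a" with left padding, forcing every sequence to length 8
    pipeline = (
        Python2TorchMapper()
        >> FixedBatchSizeMapper(batch_size="max")
        >> TensorCollatorMapper(
            fields_pad_ids={"a": -1},
            left_pad_fields=["a"],  # pad "a" on the left, not the right
            pad_to_length=8,
        )
    )

    output = pipeline.map([{"a": [1, 2, 3]}, {"a": [4, 5]}])
    # output[0]["a"] should be:
    # tensor([[-1, -1, -1, -1, -1,  1,  2,  3],
    #         [-1, -1, -1, -1, -1, -1,  4,  5]])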
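Second, the output_rename_map added to GetTokenizerOutputFieldsAndNamesMixIn,
mirroring test_rename above. "bert-base-uncased" is a placeholder checkpoint
(the tests use their own fixture tokenizer); any pretrained tokenizer should
behave the same way:

    from transformers import AutoTokenizer

    from smashed.mappers import TokenizerMapper

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    mapper = TokenizerMapper(
        input_field="text",
        tokenizer=tokenizer,
        return_attention_mask=True,
        output_rename_map={"input_ids": "foo", "attention_mask": "bar"},
    )

    out = mapper.map([{"text": "This is a sentence."}])
    # tokenizer outputs are renamed rather than prefixed
    assert "foo" in out[0] and "bar" in out[0]

    # a tokenizer output missing from the map raises a ValueError at
    # mapping time (see fname); passing both output_prefix and
    # output_rename_map fails the assert in the mixin's __init__.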
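Last, the new raise_on_missing flag in src/smashed/mappers/fields.py. The
class name is not visible in the hunk above; this sketch assumes it is the
ChangeFieldsMapper exported by smashed.mappers:

    from smashed.mappers import ChangeFieldsMapper

    # with raise_on_missing=False the mapper no longer registers its
    # input/output fields, so samples that lack "b" pass through
    # instead of triggering a missing-field error
    mapper = ChangeFieldsMapper(drop_fields=["b"], raise_on_missing=False)

    out = mapper.map([{"a": 1, "b": 2}, {"a": 3}])
    # expected: [{"a": 1}, {"a": 3}]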