Soldni/stride fix (#39)
* fix on how last stride is computed

* fixed small issue with collator in case ignore does not exist

* improvements to collator

* left padding support

* rename support

* pretty

* tests for tokenizer

* fixed issues with datasets 2.8.0 refactor

* necessary library syntax fix
soldni authored Jan 6, 2023
1 parent 201e28d commit bdbc617
Showing 14 changed files with 334 additions and 66 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.12.0"
version = "0.13.0"
description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
authors = [
{name = "Allen Institute for Artificial Intelligence", email = "[email protected]" },
18 changes: 13 additions & 5 deletions src/smashed/base/interfaces.py
@@ -28,7 +28,15 @@

with necessary("datasets", soft=True) as HUGGINGFACE_DATASET_AVAILABLE:
if HUGGINGFACE_DATASET_AVAILABLE or TYPE_CHECKING:
from datasets.arrow_dataset import Batch, Dataset
from datasets.arrow_dataset import Dataset

try:
from datasets.formatting.formatting import LazyBatch
except ImportError:
# pre datasets 2.8.0
from datasets.arrow_dataset import (
Batch as LazyBatch, # pyright: ignore
)
from datasets.iterable_dataset import IterableDataset

HuggingFaceDataset = TypeVar(
@@ -284,12 +292,12 @@ def _map_huggingface_dataset(
else:
return transformed_dataset

@map.add_interface(dataset=Batch)
@map.add_interface(dataset=LazyBatch)
def _map_huggingface_dataset_batch(
self,
dataset: Batch,
dataset: LazyBatch,
**map_kwargs: Any,
) -> Batch:
) -> LazyBatch:
# explicitly casting to a boolean since this is all that is
# supported by the simple mapper.
# TODO[lucas]: maybe support specifying which fields to keep?
@@ -298,7 +306,7 @@ def _map_huggingface_dataset_batch(
or self.always_remove_columns
)

dtview: DataBatchView[Batch, str, Any] = DataBatchView(dataset)
dtview: DataBatchView[LazyBatch, str, Any] = DataBatchView(dataset)

self._check_fields_datasets(
provided_fields=dataset.keys(),
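The interfaces.py change is a compatibility shim for the datasets 2.8.0 refactor: the lazily formatted batch class moved, so the old `Batch` import is kept only as a fallback and everything downstream refers to the `LazyBatch` name. The same shim is applied in cache.py below. A minimal standalone sketch of the pattern, assuming only what the diff shows (the `describe_batch` helper is illustrative):

```python
try:
    # datasets >= 2.8.0 exposes the lazily formatted batch here
    from datasets.formatting.formatting import LazyBatch
except ImportError:
    # pre datasets 2.8.0: the equivalent class lived in arrow_dataset
    from datasets.arrow_dataset import Batch as LazyBatch


def describe_batch(batch: LazyBatch) -> str:
    # downstream code only ever names LazyBatch, so the version
    # difference stays contained in the import block above
    return f"batch with columns: {sorted(batch.keys())}"
```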
20 changes: 15 additions & 5 deletions src/smashed/mappers/cache.py
@@ -13,7 +13,15 @@

with necessary("datasets", soft=True) as HUGGINGFACE_DATASET_AVAILABLE:
if HUGGINGFACE_DATASET_AVAILABLE or TYPE_CHECKING:
from datasets.arrow_dataset import Dataset, Batch
from datasets.arrow_dataset import Dataset

try:
from datasets.formatting.formatting import LazyBatch
except ImportError:
# pre datasets 2.8.0
from datasets.arrow_dataset import (
Batch as LazyBatch, # pyright: ignore
)
from datasets.iterable_dataset import IterableDataset
from datasets.fingerprint import disable_caching, enable_caching

@@ -130,8 +138,8 @@ def get_dataset_fingerprint_hf_iterable(
)
return h.hexdigest()

@get_dataset_fingerprint.add_interface(dataset=Batch)
def get_dataset_fingerprint_hf_batch(self, dataset: Batch) -> str:
@get_dataset_fingerprint.add_interface(dataset=LazyBatch)
def get_dataset_fingerprint_hf_batch(self, dataset: LazyBatch) -> str:
raise ValueError(
"Cannot cache a Batch of a HuggingFace Dataset; please "
"cache at the Dataset level instead."
@@ -198,7 +206,7 @@ def _save_hf_it(self, dataset: IterableDataset, path: Path):
"Saving an IterableDataset is not implemented yet"
)

@save_cache.add_interface(dataset=Batch)
@save_cache.add_interface(dataset=LazyBatch)
def _save_hf_batch(self, dataset: Dataset, path: Path):
raise ValueError(
"Cannot cache a Batch of a HuggingFace Dataset; please "
@@ -274,7 +282,9 @@ def _load_list(

if HUGGINGFACE_DATASET_AVAILABLE:

@load_cache.add_interface(dataset=(IterableDataset, Dataset, Batch))
@load_cache.add_interface(
dataset=(IterableDataset, Dataset, LazyBatch)
)
def _load_hf(
self,
path: Path,
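In cache.py the rename only re-keys the handlers registered through `add_interface`, which dispatches on the type of the `dataset` argument. As a rough analogy only, not smashed's actual `add_interface` machinery, the same dispatch shape can be sketched with `functools.singledispatch`; the list/dict handlers below are made up for illustration:

```python
import hashlib
from functools import singledispatch
from typing import Any, Dict, List


@singledispatch
def get_dataset_fingerprint(dataset: Any) -> str:
    raise TypeError(f"no fingerprint interface for {type(dataset).__name__}")


@get_dataset_fingerprint.register(list)
def _fingerprint_list(dataset: List[Dict[str, Any]]) -> str:
    # hash a stable representation of every row
    h = hashlib.sha1()
    for row in dataset:
        h.update(repr(sorted(row.items())).encode())
    return h.hexdigest()


@get_dataset_fingerprint.register(dict)
def _fingerprint_batch(dataset: Dict[str, Any]) -> str:
    # mirrors the behaviour in the diff: a batch view is not cacheable
    raise ValueError(
        "Cannot cache a Batch of a HuggingFace Dataset; please "
        "cache at the Dataset level instead."
    )
```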
46 changes: 42 additions & 4 deletions src/smashed/mappers/collators.py
@@ -23,6 +23,7 @@ def __init__(
pad_to_length: Optional[Union[int, Sequence[int]]] = None,
fields_pad_ids: Optional[Mapping[str, int]] = None,
unk_fields_pad_id: Optional[int] = None,
left_pad_fields: Optional[Sequence[str]] = None,
):
"""Create a collator.
@@ -41,10 +42,14 @@
unk_fields_pad_id (int, optional): The padding value to use for
any field that is not in fields_pad_ids. If not provided, an
error will be raised if a field is not in fields_pad_ids.
left_pad_fields (Sequence[str], optional): A list of fields to
pad from the left instead of the right. By default, all fields
are padded from the right.
"""
self.fields_pad_ids = fields_pad_ids or {}
self.pad_to_length = pad_to_length
self.unk_fields_pad_id = unk_fields_pad_id
self.left_pad_fields = set(left_pad_fields or [])

if self.unk_fields_pad_id is None and self.fields_pad_ids is None:
raise ValueError(
@@ -145,7 +150,25 @@ def _pad(
pad_value: int,
dim: int = 0,
pad_to_length: Optional[Union[int, Sequence[int]]] = None,
right_pad: bool = True,
) -> torch.Tensor:
"""Pad a sequence of tensors to the same length.
Args:
sequence (Sequence[torch.Tensor]): The sequence of tensors to pad.
It is assumed that all tensors in the sequence have the same
type; if not, an error might be raised somewhere.
pad_value (int): The value to use for padding.
dim (int, optional): The dimension we are collating on. Defaults
to 0.
pad_to_length (Union[int, Sequence[int]], optional): If provided,
pad all sequences to this length. If provided as a sequence,
we assume we should pad each dimension to the corresponding
length. If None, sequences will be padded to the length of the
longest sequence. Defaults to None.
right_pad (bool, optional): If True, pad to the right. If False,
pad to the left. Defaults to True.
"""

# make sure type of input is right
if not (
@@ -192,13 +215,14 @@ def _pad(
pad_shapes = tuple(
tuple(
chain.from_iterable(
(0, m - s)
(0, m - s) if right_pad else (m - s, 0)
for s, m in zip(t.size()[::-1], max_lengths[::-1])
)
)
# we do padding shapes for each tensor
for t in sequence
)

# call each pad on each of the tensors with the appropriate padding
to_stack = tuple(
torch.nn.functional.pad(
@@ -218,6 +242,7 @@ def transform( # type: ignore
sequence=list_of_tensors,
pad_value=self._get_padding_value(field_name=field_name),
pad_to_length=self.pad_to_length,
right_pad=(field_name not in self.left_pad_fields),
)
for field_name, list_of_tensors in data.items()
}
@@ -270,23 +295,29 @@ def _get_list_shape_recursive(
# this iterator will yield the shape of each element in the sequence
inner_dims = (self._get_list_shape_recursive(s) for s in sequence)

# the acutal shape is the maximum of the inner dims
# the actual shape is the maximum of the inner dims
inner_shape = tuple(max(dims) for dims in zip(*inner_dims))

return (len(sequence), *inner_shape)

def _pad_recursive(
self, sequence: List[Any], shape: Sequence[int], padding_symbol: Any
self,
sequence: List[Any],
shape: Sequence[int],
padding_symbol: Any,
pad_right: bool = True,
) -> List[Any]:
"""Recursively pads a list of [lists, ...].
Args:
sequence (List[Any]): The list to pad.
shape (Sequence[int]): The shape to pad to.
padding_symbol (Any): The symbol to pad with.
pad_right (bool, optional): If True, pads to the right. If False,
pads to the left. Defaults to True.
Returns:
List[Any]: _description_
List[Any]: The padded list.
"""

if len(shape) < 2:
Expand Down Expand Up @@ -321,7 +352,11 @@ def _pad_recursive(
#
# We do that in the following line:
sequence_with_brand_new_padding = (
# the side we pad depends on whether pad_right is True or False
sub_seq + [nested_pad_symbol] * (dim_to_pad_shape - len(sub_seq))
if pad_right
else [nested_pad_symbol] * (dim_to_pad_shape - len(sub_seq))
+ sub_seq
for sub_seq in sequence
)

@@ -342,6 +377,7 @@ def _pad(
self: "ListCollatorMapper",
seq_of_seq_to_pad: List[Any],
padding_symbol: Any,
pad_right: bool = True,
) -> List[Any]:

padding_shape = self._get_list_shape_recursive(seq_of_seq_to_pad)
@@ -367,6 +403,7 @@ def _pad(
sequence=seq_of_seq_to_pad,
shape=padding_shape,
padding_symbol=padding_symbol,
pad_right=pad_right,
)
return padded_sequence

@@ -377,6 +414,7 @@ def transform(self, data: TransformElementType) -> TransformElementType:
field_name: self._pad(
seq_of_seq_to_pad=field_value,
padding_symbol=self._get_padding_value(field_name=field_name),
pad_right=(field_name not in self.left_pad_fields),
)
for field_name, field_value in data.items()
}
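The collator changes thread a single left/right switch down to the padding primitives: for tensors the switch decides which side of each dimension `torch.nn.functional.pad` fills, and for nested lists it decides which side the padding symbols are concatenated on. A small self-contained sketch of both ideas, assuming 1-D tensors and one level of list nesting (`pad_tensors` and `pad_lists` are illustrative names, not smashed's API):

```python
from typing import Any, List, Sequence

import torch
import torch.nn.functional as F


def pad_tensors(
    tensors: Sequence[torch.Tensor], pad_value: int, right_pad: bool = True
) -> torch.Tensor:
    """Pad 1-D tensors to the longest length and stack them."""
    max_len = max(t.size(0) for t in tensors)
    padded = []
    for t in tensors:
        missing = max_len - t.size(0)
        # F.pad takes (left, right) amounts for the last dimension
        amounts = (0, missing) if right_pad else (missing, 0)
        padded.append(F.pad(t, amounts, value=pad_value))
    return torch.stack(padded)


def pad_lists(
    lists: Sequence[List[Any]], pad_symbol: Any, right_pad: bool = True
) -> List[List[Any]]:
    """Same idea for plain Python lists."""
    max_len = max(len(seq) for seq in lists)
    return [
        (
            seq + [pad_symbol] * (max_len - len(seq))
            if right_pad
            else [pad_symbol] * (max_len - len(seq)) + seq
        )
        for seq in lists
    ]


# right padding fills the tail, left padding fills the head:
# pad_tensors(...)                -> [[1, 2, 3, 0], [1, 2, 3, 4]]
# pad_lists(..., right_pad=False) -> [[0, 1, 2, 3], [1, 2, 3, 4]]
example = [torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 4])]
right = pad_tensors(example, pad_value=0)
left = pad_lists([[1, 2, 3], [1, 2, 3, 4]], pad_symbol=0, right_pad=False)
```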
19 changes: 14 additions & 5 deletions src/smashed/mappers/fields.py
@@ -27,13 +27,16 @@ def __init__(
self,
keep_fields: Optional[List[str]] = None,
drop_fields: Optional[List[str]] = None,
raise_on_missing: bool = True,
):
"""
Args:
keep_fields (List[str]): Fields to keep, all other fields
are dropped. Defaults to [].
drop_fields (List[str]): Fields to drop, all other fields
are kept. Defaults to [].
raise_on_missing (bool): Whether to raise an error if a field
is missing. Defaults to True.
"""

# xor between keep_fields and remove_fields
@@ -42,16 +45,22 @@
):
raise ValueError("Must specify `keep_fields` or `drop_fields`")

super().__init__(input_fields=drop_fields, output_fields=keep_fields)
self.keep_fields = dict.fromkeys(keep_fields) if keep_fields else None
self.drop_fields = dict.fromkeys(drop_fields) if drop_fields else None

super().__init__(
input_fields=drop_fields if raise_on_missing else None,
output_fields=keep_fields if raise_on_missing else None,
)

def transform(self, data: TransformElementType) -> TransformElementType:
if self.input_fields:
if self.drop_fields:
new_data = {
k: v for k, v in data.items() if k not in self.input_fields
k: v for k, v in data.items() if k not in self.drop_fields
}

elif self.output_fields:
new_data = {k: data[k] for k in self.output_fields}
elif self.keep_fields:
new_data = {k: data[k] for k in data if k in self.keep_fields}

else:
raise ValueError("Must specify `keep_fields` or `drop_fields`")
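The fields.py change keeps the keep/drop sets on the mapper itself and only wires them into `input_fields`/`output_fields` when `raise_on_missing=True`, so filtering no longer has to fail on rows that lack one of the named fields. A minimal sketch of the resulting filtering logic, written as a free function for illustration rather than as the mapper class:

```python
from typing import Any, Dict, Optional, Sequence


def filter_fields(
    row: Dict[str, Any],
    keep_fields: Optional[Sequence[str]] = None,
    drop_fields: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    # exactly one of keep_fields / drop_fields must be provided
    if (keep_fields is None) == (drop_fields is None):
        raise ValueError("Must specify `keep_fields` or `drop_fields`")

    if drop_fields is not None:
        drop = set(drop_fields)
        return {k: v for k, v in row.items() if k not in drop}

    keep = set(keep_fields or ())
    # keep only the requested fields that are actually present,
    # mirroring the lenient raise_on_missing=False behaviour
    return {k: v for k, v in row.items() if k in keep}


# returns {'text': 'hi'}: 'label' is requested but absent, so it is skipped
filtered = filter_fields({"text": "hi"}, keep_fields=["text", "label"])
```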
16 changes: 11 additions & 5 deletions src/smashed/mappers/glom.py
@@ -8,7 +8,13 @@

with necessary("datasets", soft=True) as DATASETS_AVAILABLE:
if DATASETS_AVAILABLE:
from datasets.arrow_dataset import Example
try:
from datasets.formatting.formatting import LazyRow
except ImportError:
# pre datasets 2.8.0
from datasets.arrow_dataset import (
Example as LazyRow, # pyright: ignore
)


class ExtendGlommerMixin:
@@ -26,10 +32,10 @@ def glommer(self) -> glom.Glommer:

if DATASETS_AVAILABLE:
glommer.register(
target_type=Example,
get=Example.__getitem__,
iter=Example.__iter__,
exact=Example.__eq__,
target_type=LazyRow,
get=LazyRow.__getitem__,
iter=LazyRow.__iter__,
exact=LazyRow.__eq__,
)

glommer.register(
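The glom.py change only swaps `Example` for `LazyRow` in the handler registration; the registration call itself is unchanged. The sketch below applies the same `glom.Glommer` registration to a made-up wrapper type; the `Record` class is illustrative, and the final lookup assumes `Glommer.glom` behaves like the module-level `glom.glom`:

```python
import glom


class Record:
    """A tiny read-only mapping wrapper used only for this example."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __eq__(self, other):
        return isinstance(other, Record) and self._data == other._data


glommer = glom.Glommer()
# same shape as the registration in the diff: teach glom how to
# index, iterate over, and compare the wrapper type
glommer.register(
    target_type=Record,
    get=Record.__getitem__,
    iter=Record.__iter__,
    exact=Record.__eq__,
)

# specs like "a.b" can now traverse the wrapper like a plain dict
value = glommer.glom(Record({"a": {"b": 1}}), "a.b")  # value == 1
```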
19 changes: 17 additions & 2 deletions src/smashed/mappers/multiseq.py
@@ -396,6 +396,7 @@ def transform(
) >= self.max_stride_count

if stride_too_long or stride_has_too_many_seqs:

yield {
k: (
# if a list of fields to strides has been provided,
@@ -423,7 +424,21 @@
cumulative_stride_length += current_seq_length

# yield the last sequence
out = {k: v[seq_pos_start:] for k, v in sample.items()}
out = {
k: (
# same logic as above: if a list of fields to strides
# has been provided, then only stride this field if it
# is in the list and duplicate if it is not; if no list
# of fields to stride has been provided, then stride all.
v[seq_pos_start:]
if (
self.fields_to_stride is None
or k in self.fields_to_stride
)
else v
)
for k, v in sample.items()
}

yield out

@@ -434,7 +449,7 @@ def __init__(
single_value_field: str,
like_field: str = "input_ids",
strategy: Literal["first", "last", "all"] = "first",
padding_id: Union[int, float] = -100,
padding_id: Any = -100,
) -> None:
"""Mapper to create a sequence of values from single value.
Useful when casting a sequence classification task to a sequence
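The multiseq.py fix makes the final yield honor `fields_to_stride` the same way every earlier stride does, instead of slicing all fields unconditionally. A simplified sketch of that last step with hypothetical sample data (the helper name is illustrative):

```python
from typing import Any, Dict, List, Optional, Sequence


def slice_last_stride(
    sample: Dict[str, List[Any]],
    seq_pos_start: int,
    fields_to_stride: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Slice the trailing stride; fields outside fields_to_stride are copied as-is."""
    return {
        k: (
            v[seq_pos_start:]
            if (fields_to_stride is None or k in fields_to_stride)
            else v
        )
        for k, v in sample.items()
    }


# 'input_ids' is strided from position 2, 'labels' is duplicated unchanged
last = slice_last_stride(
    {"input_ids": [[1, 2], [3, 4], [5, 6]], "labels": [0]},
    seq_pos_start=2,
    fields_to_stride=["input_ids"],
)
# last == {"input_ids": [[5, 6]], "labels": [0]}
```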