Skip to content

Commit

Permalink
torch optional (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni authored May 27, 2023
1 parent b0faa0f commit 376f6b3
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 24 deletions.
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.20.0"
version = "0.21.0"
description = """\
SMASHED is a toolkit designed to apply transformations to samples in \
datasets, such as fields extraction, tokenization, prompting, batching, \
Expand All @@ -11,7 +11,6 @@ license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"torch>=1.9",
"necessary>=0.4.1",
"trouting>=0.3.3",
"ftfy>=6.1.1",
Expand Down Expand Up @@ -103,12 +102,17 @@ remote = [
"smart-open>=5.2.1",
"boto3>=1.25.5",
]
torch = [
"torch>=1.9",
]
datasets = [
"smashed[torch]",
"transformers>=4.5",
"datasets>=2.8.0",
"dill>=0.3.0",
]
prompting = [
"smashed[torch]",
"transformers>=4.5",
"promptsource>=0.2.3",
"blingfire>=0.1.8",
Expand All @@ -119,6 +123,7 @@ torchdata = [
]
all = [
"smashed[dev]",
"smashed[torch]",
"smashed[datasets]",
"smashed[torchdata]",
"smashed[remote]",
Expand Down
9 changes: 4 additions & 5 deletions src/smashed/base/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)

from necessary import necessary
from torch._utils import classproperty
from trouting import trouting

from .abstract import (
Expand Down Expand Up @@ -52,7 +51,7 @@ class MapMethodInterfaceMixIn(AbstractBaseMapper):
and various interfaces. Do not inherit from this class directly,
but use SingleBaseMapper/BatchedBaseMapper instead."""

@classproperty
@classmethod
def always_remove_columns(cls) -> bool:
"""Whether this mapper should always remove its input columns
from the dataset. If False, the mapper will only remove columns
Expand Down Expand Up @@ -191,7 +190,7 @@ def _map_list_of_dicts(
# TODO[lucas]: maybe support specifying which fields to keep?
remove_columns = (
bool(map_kwargs.get("remove_columns", False))
or self.always_remove_columns
or self.always_remove_columns()
)

if isinstance(dataset, abc.Sequence):
Expand Down Expand Up @@ -258,7 +257,7 @@ def _map_huggingface_dataset(

print_fingerprint = map_kwargs.pop("print_fingerprint", False)

if self.always_remove_columns:
if self.always_remove_columns():
remove_columns = list(dataset.features.keys())
else:
remove_columns = map_kwargs.get("remove_columns", [])
Expand Down Expand Up @@ -320,7 +319,7 @@ def _map_huggingface_dataset_batch(
# TODO[lucas]: maybe support specifying which fields to keep?
remove_columns = (
bool(map_kwargs.get("remove_columns", False))
or self.always_remove_columns
or self.always_remove_columns()
)

dtview: DataBatchView[LazyBatch, str, Any] = DataBatchView(dataset)
Expand Down
19 changes: 14 additions & 5 deletions src/smashed/mappers/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Union,
)

import torch
from necessary import necessary

from ..base import SingleBaseMapper, TransformElementType
Expand All @@ -26,6 +25,10 @@
PreTrainedTokenizerBase,
)

with necessary("torch", soft=True) as PYTORCH_AVAILABLE:
if PYTORCH_AVAILABLE or TYPE_CHECKING:
import torch


__all__ = [
"ListCollatorMapper",
Expand Down Expand Up @@ -170,15 +173,21 @@ class TensorCollatorMapper(BaseCollator, SingleBaseMapper):
>>> data_loader = DataLoader(..., collate_fn=collator.transform)
"""

def __init__(self, *args, **kwargs):
if not PYTORCH_AVAILABLE:
cls_name = self.__class__.__name__
raise ImportError(f"Pytorch is required to use {cls_name}")
super().__init__(*args, **kwargs)

@staticmethod
def _pad(
sequence: Sequence[torch.Tensor],
sequence: Sequence["torch.Tensor"],
pad_value: Union[int, float],
dim: int = 0,
pad_to_length: Optional[Union[int, Sequence[int]]] = None,
pad_to_multiple_of: Optional[int] = None,
right_pad: bool = True,
) -> torch.Tensor:
) -> "torch.Tensor":
"""Pad a sequence of tensors to the same length.
Args:
Expand Down Expand Up @@ -272,8 +281,8 @@ def _pad(
return torch.cat(to_stack, dim=dim)

def transform( # type: ignore
self: "TensorCollatorMapper", data: Dict[str, Sequence[torch.Tensor]]
) -> Dict[str, torch.Tensor]:
self: "TensorCollatorMapper", data: Dict[str, Sequence["torch.Tensor"]]
) -> Dict[str, "torch.Tensor"]:
collated_data = {
field_name: self._pad(
sequence=list_of_tensors,
Expand Down
23 changes: 16 additions & 7 deletions src/smashed/mappers/converters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, TypeVar, Union

import torch
from necessary import necessary
from trouting import trouting

Expand All @@ -15,16 +14,22 @@
"HuggingFaceDataset", Dataset, IterableDataset
)

with necessary("torch", soft=True) as PYTORCH_AVAILABLE:
if PYTORCH_AVAILABLE or TYPE_CHECKING:
import torch


class Python2TorchMapper(SingleBaseMapper):
__slots__ = ["field_cast_map", "device"]
field_cast_map: Dict[str, torch.dtype]
device: Union[torch.device, None]
field_cast_map: Dict[str, "torch.dtype"]
device: Union["torch.device", None]

def __init__(
self: "Python2TorchMapper",
field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None,
device: Optional[Union[torch.device, str]] = None,
field_cast_map: Optional[
Mapping[str, Union[str, "torch.dtype"]]
] = None,
device: Optional[Union["torch.device", str]] = None,
) -> None:
"""Mapper that converts Python types to Torch types. It can optionally
cast the values of a field to a specific type, and move to a specific
Expand All @@ -37,6 +42,10 @@ def __init__(
device (Union[torch.device, str], optional): Device to move the
tensors to. Defaults to None, which means no moving occurs.
"""
if not PYTORCH_AVAILABLE:
cls_name = self.__class__.__name__
raise ImportError(f"{cls_name} requires PyTorch to be installed")

self.device = torch.device(device) if device else None

self.field_cast_map = {
Expand All @@ -49,7 +58,7 @@ def __init__(
)

@staticmethod
def _get_dtype(dtype: Any) -> torch.dtype:
def _get_dtype(dtype: Any) -> "torch.dtype":
if isinstance(dtype, str):
dtype = getattr(torch, dtype, None)
if dtype is None:
Expand Down Expand Up @@ -102,7 +111,7 @@ def __init__(self: "Torch2PythonMapper") -> None:
super().__init__()

def transform( # type: ignore
self: "Torch2PythonMapper", data: Dict[str, torch.Tensor]
self: "Torch2PythonMapper", data: Dict[str, "torch.Tensor"]
) -> TransformElementType:
return {
field_name: field_value.cpu().tolist()
Expand Down
5 changes: 2 additions & 3 deletions src/smashed/mappers/fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar

from necessary import necessary
from torch._utils import classproperty

from ..base import SingleBaseMapper, TransformElementType

Expand All @@ -19,7 +18,7 @@ class ChangeFieldsMapper(SingleBaseMapper):
"""Mapper that removes some of the fields in a dataset.
Either `keep_fields` or `drop_fields` must be specified, but not both."""

@classproperty
@classmethod
def always_remove_columns(cls) -> bool:
return True

Expand Down Expand Up @@ -71,7 +70,7 @@ def transform(self, data: TransformElementType) -> TransformElementType:
class RenameFieldsMapper(SingleBaseMapper):
"""Mapper that renames some of the fields batch"""

@classproperty
@classmethod
def always_remove_columns(cls) -> bool:
return True

Expand Down
2 changes: 1 addition & 1 deletion src/smashed/mappers/glom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ExtendGlommerMixin:

def __getstate__(self):
state = super().__getstate__() # pyright: ignore
state["__dict__"].pop("glommer", None)
state["__dict__"].pop("glommer", None) # pyright: ignore
return state

@cached_property
Expand Down
1 change: 0 additions & 1 deletion src/smashed/utils/io_utils/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def compress_stream(
errors: str = "strict",
gzip: bool = True,
) -> Iterator[IO]:

assert gzip, "Only gzip compression is supported at this time"

if mode == "wb" or mode == "w":
Expand Down

0 comments on commit 376f6b3

Please sign in to comment.