Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explain on the fly transforms for PyTorch #2513

Merged
merged 4 commits into from
Oct 14, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions datasets/doc/source/how-to-use-with-pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,37 @@ vary e.g. "img" or "image", "label" or "labels"::

partition.features

In case of CIFAR10, you should see the following output
In case of CIFAR10, you should see the following output.

.. code-block:: none

{'img': Image(decode=True, id=None),
'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
'frog', 'horse', 'ship', 'truck'], id=None)}

Apply Transforms, Create DataLoader. We will use the `map() <https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/main_classes#datasets.Dataset.map>`_
function. Please note that the map will modify the existing dataset if the key in the dictionary you return is already present
and append a new feature if it did not exist before. Below, we modify the "img" feature of our dataset.::

Apply Transforms, Create DataLoader. We will use `Dataset.with_transform() <https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/main_classes#datasets.Dataset.with_transform>`_.
It works on-the-fly, meaning the transforms you specified will be applied only when you access the data, which is also how the transforms work in the PyTorch ecosystem.
The last detail is to know that this function works on the batches of data (even if you select a single element, it is represented as a batch).
That is why we iterate over all the samples from this batch and apply our transforms::

from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

transforms = ToTensor()
def apply_transforms(batch):
batch["img"] = [transforms(img) for img in batch["img"]]
return batch

partition_torch = partition.with_transform(apply_transforms)
# At this point, you can check if you didn't make any mistakes by calling partition_torch[0]
dataloader = DataLoader(partition_torch, batch_size=64)


Alternatively, you can use the `map() <https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/main_classes#datasets.Dataset.map>`_
function. Note that the operation is instant (contrary to the set_transform and with_transform). Remember that the map
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
will modify the existing dataset if the key in the dictionary you return is already present and append a new feature if
it did not exist before. Below, we modify the "img" feature of our dataset.::
danieljanes marked this conversation as resolved.
Show resolved Hide resolved

from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
Expand Down