From 09b46884facdecbd43ffc7d8a7717e1c8f311a01 Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Wed, 5 Jul 2023 12:52:00 +1000 Subject: [PATCH 1/5] * rename arg dim -> embedding_dim * fix docstring * add test and cleanup * add example --- examples/pytorch/01-Getting-started.ipynb | 364 ++++++++++++++++++++++ merlin/models/torch/models/ranking.py | 10 +- tests/unit/torch/models/test_ranking.py | 2 +- 3 files changed, 370 insertions(+), 6 deletions(-) create mode 100644 examples/pytorch/01-Getting-started.ipynb diff --git a/examples/pytorch/01-Getting-started.ipynb b/examples/pytorch/01-Getting-started.ipynb new file mode 100644 index 0000000000..3eb47a0c4d --- /dev/null +++ b/examples/pytorch/01-Getting-started.ipynb @@ -0,0 +1,364 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bb28e271", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2023 NVIDIA Corporation. All Rights Reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "# ==============================================================================\n", + "\n", + "# Each user is responsible for checking the content of datasets and the\n", + "# applicable licenses and determining if suitable for the intended use." + ] + }, + { + "cell_type": "markdown", + "id": "23d9bf34", + "metadata": {}, + "source": [ + "\n", + "\n", + "# Getting Started with Merlin Models: Develop a Model for MovieLens using the PyTorch API\n", + "\n", + "This notebook is created using the latest stable [merlin-pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-pytorch/tags) container. \n", + "\n", + "## Overview\n", + "\n", + "[Merlin Models](https://github.com/NVIDIA-Merlin/models/) is a library for training recommender models. Merlin Models let Data Scientists and ML Engineers easily train standard RecSys models on their own dataset, getting GPU-accelerated models with best practices baked into the library. This will also let researchers to build custom models by incorporating standard components of deep learning recommender models, and then benchmark their new models on example offline datasets. Merlin Models is part of the [Merlin open source framework](https://developer.nvidia.com/nvidia-merlin).\n", + "\n", + "Core features are:\n", + "- Many different recommender system architectures (tabular, two-tower, sequential) or tasks (binary, multi-class classification, multi-task)\n", + "- Flexible APIs targeted to both production and research\n", + "- Deep integration with NVIDIA Merlin platform, including NVTabular for ETL and Merlin Systems model serving\n", + "\n", + "\n", + "### Learning objectives\n", + "\n", + "- Training [Facebook's DLRM model](https://arxiv.org/pdf/1906.00091.pdf) very easily with our high-level API.\n", + "- Understanding Merlin Models high-level API" + ] + }, + { + "cell_type": "markdown", + "id": "1c5598ae", + "metadata": {}, + "source": [ + "## Downloading and preparing the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60653f70", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", + "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "import merlin.models.torch as mm\n", + "from merlin.loader.torch import Loader\n", + "import pytorch_lightning as pl\n", + "\n", + "from merlin.datasets.entertainment import get_movielens" + ] + }, + { + "cell_type": "markdown", + "id": "5327924b", + "metadata": {}, + "source": [ + "We provide the `get_movielens()` function as a convenience to download the dataset, perform simple preprocessing, and split the data into training and validation datasets." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9ba8b53d", + "metadata": {}, + "outputs": [], + "source": [ + "input_path = os.environ.get(\"INPUT_DATA_DIR\", os.path.expanduser(\"~/merlin-models-data/movielens/\"))\n", + "train, valid = get_movielens(variant=\"ml-1m\", path=input_path)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee5c7c2", + "metadata": {}, + "source": [ + "## Training the DLRM Model with Merlin Models" + ] + }, + { + "cell_type": "markdown", + "id": "688b89c7", + "metadata": {}, + "source": [ + "We define the DLRM model, whose prediction task is a binary classification. From the `schema`, the categorical features are identified (and embedded) and the target columns are also automatically inferred, because of the schema tags. We talk more about the schema in the next [example notebook (02)](02-Merlin-Models-and-NVTabular-integration.ipynb)," + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d3b8942c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n", + " warnings.warn('Lazy modules are a new feature under heavy development '\n" + ] + } + ], + "source": [ + "model = mm.DLRMModel(\n", + " train.schema,\n", + " embedding_dim=64,\n", + " bottom_block=mm.MLPBlock([128, 64]),\n", + " top_block=mm.MLPBlock([128, 64, 32]),\n", + " output_block=mm.BinaryOutput(train.schema.select_by_name('rating_binary')),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "64ee4cef", + "metadata": {}, + "source": [ + "Next, we train the model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "33343067", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + " warning_cache.warn(\n", + "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py:70: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + " rank_zero_warn(\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------\n", + "0 | values | ModuleList | 1.1 M \n", + "--------------------------------------\n", + "1.1 M Trainable params\n", + "0 Non-trainable params\n", + "1.1 M Total params\n", + "4.459 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 100.68it/s, v_num=7]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=1` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 100.35it/s, v_num=7]\n" + ] + } + ], + "source": [ + "trainer = pl.Trainer(max_epochs=1)\n", + "\n", + "with Loader(train, batch_size=1024) as loader:\n", + " model.initialize(loader)\n", + " trainer.fit(model, loader)" + ] + }, + { + "cell_type": "markdown", + "id": "4bd668ab", + "metadata": {}, + "source": [ + "We evaluate the model and check the evaluation metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "34f01ce5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation DataLoader 0: 7%|██████████████████████▍ | 14/196 [00:00<00:01, 157.25it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1024. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", + " warning_cache.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:01<00:00, 164.10it/s]\n", + "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", + " Validate metric DataLoader 0\n", + "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", + " valid_binary_accuracy 0.7193925380706787\n", + " valid_binary_auroc 0.7811704874038696\n", + " valid_binary_precision 0.7327381372451782\n", + " valid_binary_recall 0.8074849843978882\n", + " valid_loss 0.5524654984474182\n", + "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 361. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", + " warning_cache.warn(\n" + ] + } + ], + "source": [ + "with Loader(valid, batch_size=1024) as loader:\n", + " metrics = trainer.validate(model, loader)" + ] + }, + { + "cell_type": "markdown", + "id": "2a6ad327", + "metadata": {}, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "markdown", + "id": "eeba861b", + "metadata": {}, + "source": [ + "Merlin Models enables users to define and train a deep learning recommeder model with just a handful of commands.\n", + "\n", + "```python\n", + "model = mm.DLRMModel(\n", + " train.schema,\n", + " embedding_dim=64,\n", + " bottom_block=mm.MLPBlock([128, 64]),\n", + " top_block=mm.MLPBlock([128, 64, 32]),\n", + " output_block=mm.BinaryOutput(train.schema.select_by_name('rating_binary')),\n", + ")\n", + "\n", + "trainer = pl.Trainer(max_epochs=1)\n", + "\n", + "with Loader(train, batch_size=1024) as loader:\n", + " model.initialize(loader)\n", + " trainer.fit(model, loader)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "c9180e84", + "metadata": {}, + "source": [ + "## Next steps" + ] + }, + { + "cell_type": "markdown", + "id": "b9ba3102", + "metadata": {}, + "source": [ + "In the next example notebooks, we will show how the integration with NVTabular and how to explore different recommender models." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "merlin": { + "containers": [ + "nvcr.io/nvidia/merlin/merlin-tensorflow:latest" + ] + }, + "vscode": { + "interpreter": { + "hash": "ab403bb43341787581f43b51cdd291d61392c89ddb0f92179de653921d4e05db" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/merlin/models/torch/models/ranking.py b/merlin/models/torch/models/ranking.py index 292abebbd8..b34c56403f 100644 --- a/merlin/models/torch/models/ranking.py +++ b/merlin/models/torch/models/ranking.py @@ -17,8 +17,8 @@ class DLRMModel(Model): ---------- schema : Schema The schema to use for selection. - dim : int - The dimensionality of the output vectors. + embedding_dim : int + The dimensionality of the embedding vectors for CONTINUOUS and CATEGORICAL features. bottom_block : Block Block to pass the continuous features to. Note that, the output dimensionality of this block must be equal to ``dim``. @@ -40,7 +40,7 @@ class DLRMModel(Model): ------------- >>> model = mm.DLRMModel( ... schema, - ... dim=64, + ... embedding_dim=64, ... bottom_block=mm.MLPBlock([256, 64]), ... output_block=BinaryOutput(ColumnSchema("target"))) >>> trainer = pl.Trainer() @@ -56,7 +56,7 @@ class DLRMModel(Model): def __init__( self, schema: Schema, - dim: int, + embedding_dim: int, bottom_block: Block, top_block: Optional[Block] = None, interaction: Optional[nn.Module] = None, @@ -67,7 +67,7 @@ def __init__( dlrm_body = DLRMBlock( schema, - dim, + embedding_dim, bottom_block, top_block=top_block, interaction=interaction, diff --git a/tests/unit/torch/models/test_ranking.py b/tests/unit/torch/models/test_ranking.py index 0fb463e0ef..d43d47938d 100644 --- a/tests/unit/torch/models/test_ranking.py +++ b/tests/unit/torch/models/test_ranking.py @@ -20,7 +20,7 @@ def test_train_dlrm_with_lightning_loader( model = mm.DLRMModel( schema, - dim=dim, + embedding_dim=dim, bottom_block=mm.MLPBlock([4, 2]), top_block=mm.MLPBlock([4, 2]), output_block=output_block, From b54569bfef82215094a7a9944188c672c0ac2d6f Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Wed, 5 Jul 2023 22:26:18 +1000 Subject: [PATCH 2/5] update logo --- examples/pytorch/01-Getting-started.ipynb | 52 ++++++++--------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/examples/pytorch/01-Getting-started.ipynb b/examples/pytorch/01-Getting-started.ipynb index 3eb47a0c4d..3df3eebcd6 100644 --- a/examples/pytorch/01-Getting-started.ipynb +++ b/examples/pytorch/01-Getting-started.ipynb @@ -31,7 +31,7 @@ "id": "23d9bf34", "metadata": {}, "source": [ - "\n", + "\n", "\n", "# Getting Started with Merlin Models: Develop a Model for MovieLens using the PyTorch API\n", "\n", @@ -189,7 +189,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 100.68it/s, v_num=7]" + "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 101.54it/s, v_num=10]" ] }, { @@ -203,16 +203,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 100.35it/s, v_num=7]\n" + "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 101.21it/s, v_num=10]\n" ] } ], "source": [ "trainer = pl.Trainer(max_epochs=1)\n", + "train_loader = Loader(train, batch_size=1024)\n", "\n", - "with Loader(train, batch_size=1024) as loader:\n", - " model.initialize(loader)\n", - " trainer.fit(model, loader)" + "model.initialize(train_loader)\n", + "trainer.fit(model, train_loader)" ] }, { @@ -240,7 +240,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Validation DataLoader 0: 7%|██████████████████████▍ | 14/196 [00:00<00:01, 157.25it/s]" + "Validation DataLoader 0: 7%|██████▌ | 13/196 [00:00<00:01, 161.86it/s]" ] }, { @@ -255,15 +255,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:01<00:00, 164.10it/s]\n", + "Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:01<00:00, 165.71it/s]\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", " Validate metric DataLoader 0\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", - " valid_binary_accuracy 0.7193925380706787\n", - " valid_binary_auroc 0.7811704874038696\n", - " valid_binary_precision 0.7327381372451782\n", - " valid_binary_recall 0.8074849843978882\n", - " valid_loss 0.5524654984474182\n", + " val_binary_accuracy 0.7193425297737122\n", + " val_binary_auroc 0.7803523540496826\n", + " val_binary_precision 0.7274115681648254\n", + " val_binary_recall 0.8201844692230225\n", + " val_loss 0.5525734424591064\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n" ] }, @@ -277,8 +277,8 @@ } ], "source": [ - "with Loader(valid, batch_size=1024) as loader:\n", - " metrics = trainer.validate(model, loader)" + "val_loader = Loader(valid, batch_size=1024)\n", + "metrics = trainer.validate(model, val_loader)" ] }, { @@ -306,28 +306,12 @@ ")\n", "\n", "trainer = pl.Trainer(max_epochs=1)\n", + "train_loader = Loader(train, batch_size=1024)\n", "\n", - "with Loader(train, batch_size=1024) as loader:\n", - " model.initialize(loader)\n", - " trainer.fit(model, loader)\n", + "model.initialize(train_loader)\n", + "trainer.fit(model, train_loader)\n", "```" ] - }, - { - "cell_type": "markdown", - "id": "c9180e84", - "metadata": {}, - "source": [ - "## Next steps" - ] - }, - { - "cell_type": "markdown", - "id": "b9ba3102", - "metadata": {}, - "source": [ - "In the next example notebooks, we will show how the integration with NVTabular and how to explore different recommender models." - ] } ], "metadata": { From 7de630de271a29757430bb271fefae9cb4b9af06 Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Wed, 5 Jul 2023 22:38:38 +1000 Subject: [PATCH 3/5] add test --- .../torch/examples/test_01_getting_started.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/unit/torch/examples/test_01_getting_started.py diff --git a/tests/unit/torch/examples/test_01_getting_started.py b/tests/unit/torch/examples/test_01_getting_started.py new file mode 100644 index 0000000000..d0ae8dc5f9 --- /dev/null +++ b/tests/unit/torch/examples/test_01_getting_started.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from testbook import testbook + +from tests.conftest import REPO_ROOT + + +@testbook(REPO_ROOT / "examples/pytorch/01-Getting-started.ipynb", execute=False) +@pytest.mark.notebook +def test_example_01_getting_started(tb): + tb.inject( + """ + from unittest.mock import patch + from merlin.datasets.synthetic import generate_data + mock_train, mock_valid = generate_data( + input="movielens-1m", + num_rows=1000, + set_sizes=(0.8, 0.2) + ) + p1 = patch( + "merlin.datasets.entertainment.get_movielens", + return_value=[mock_train, mock_valid] + ) + p1.start() + """ + ) + tb.execute() + metrics = tb.ref("metrics") + assert set(metrics[0].keys()) == set( + [ + "val_loss", + "val_binary_accuracy", + "val_binary_auroc", + "val_binary_precision", + "val_binary_recall" + ] + ) From 73689704c9dec2c228d520aa609a695554c9939d Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Wed, 5 Jul 2023 22:49:07 +1000 Subject: [PATCH 4/5] appease linter --- tests/unit/torch/examples/test_01_getting_started.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/torch/examples/test_01_getting_started.py b/tests/unit/torch/examples/test_01_getting_started.py index d0ae8dc5f9..2dcaea452a 100644 --- a/tests/unit/torch/examples/test_01_getting_started.py +++ b/tests/unit/torch/examples/test_01_getting_started.py @@ -46,6 +46,6 @@ def test_example_01_getting_started(tb): "val_binary_accuracy", "val_binary_auroc", "val_binary_precision", - "val_binary_recall" + "val_binary_recall", ] ) From b1d04bb3c1f926bcdec88a28ba6529f10bba4a92 Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Thu, 6 Jul 2023 18:34:05 +1000 Subject: [PATCH 5/5] add information regarding `initialize` --- examples/pytorch/01-Getting-started.ipynb | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/pytorch/01-Getting-started.ipynb b/examples/pytorch/01-Getting-started.ipynb index 3df3eebcd6..a060f8bdd1 100644 --- a/examples/pytorch/01-Getting-started.ipynb +++ b/examples/pytorch/01-Getting-started.ipynb @@ -211,6 +211,8 @@ "trainer = pl.Trainer(max_epochs=1)\n", "train_loader = Loader(train, batch_size=1024)\n", "\n", + "# The initialize step ensures the model and data are on the correct device\n", + "# and prepares the model for training\n", "model.initialize(train_loader)\n", "trainer.fit(model, train_loader)" ]