From 61f67ab3861af3fc4b506ec25663bb56a6dd7803 Mon Sep 17 00:00:00 2001 From: edknv <109497216+edknv@users.noreply.github.com> Date: Tue, 18 Apr 2023 00:08:19 -0700 Subject: [PATCH] Update multi-gpu notebook to set cupy device (#675) * Update multi-gpu notebook to set cupy device * dummy commit to re-trigger pipeline --- .../03-Session-based-Yoochoose-multigpu-training-PyT.ipynb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb b/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb index 488e18bb0a..8308071468 100644 --- a/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb +++ b/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb @@ -112,21 +112,24 @@ "source": [ "%%writefile './pyt_trainer.py'\n", "\n", + "import argparse\n", "import os\n", "import glob\n", "import torch \n", "\n", + "import cupy\n", + "\n", "from transformers4rec import torch as tr\n", "from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt\n", "from transformers4rec.torch.utils.examples_utils import wipe_memory\n", "from merlin.schema import Schema\n", "from merlin.io import Dataset\n", "\n", - "import argparse\n", + "\n", + "cupy.cuda.Device(int(os.environ[\"LOCAL_RANK\"])).use()\n", "\n", "# define arguments that can be passed to this python script\n", "parser = argparse.ArgumentParser(description='Hyperparameters for model training')\n", - "parser.add_argument('--local_rank', type=int, default=0)\n", "parser.add_argument('--path', type=str, help='Directory with training and validation data')\n", "parser.add_argument('--learning-rate', type=float, default=0.0005, help='Learning rate for training')\n", "parser.add_argument('--per-device-train-batch-size', type=int, default=384, help='Per device batch size for training')\n",