Skip to content

Commit

Permalink
Update multi-gpu notebook to set cupy device (#675)
Browse files: browse the repository at this point in the history
* Update multi-gpu notebook to set cupy device

* dummy commit to re-trigger pipeline
  • Loading branch information
edknv authored Apr 18, 2023
1 parent 517cb5f commit 61f67ab
Showing 1 changed file with 5 additions and 2 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -112,21 +112,24 @@
"source": [
"%%writefile './pyt_trainer.py'\n",
"\n",
"import argparse\n",
"import os\n",
"import glob\n",
"import torch \n",
"\n",
"import cupy\n",
"\n",
"from transformers4rec import torch as tr\n",
"from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt\n",
"from transformers4rec.torch.utils.examples_utils import wipe_memory\n",
"from merlin.schema import Schema\n",
"from merlin.io import Dataset\n",
"\n",
"import argparse\n",
"\n",
"cupy.cuda.Device(int(os.environ[\"LOCAL_RANK\"])).use()\n",
"\n",
"# define arguments that can be passed to this python script\n",
"parser = argparse.ArgumentParser(description='Hyperparameters for model training')\n",
"parser.add_argument('--local_rank', type=int, default=0)\n",
"parser.add_argument('--path', type=str, help='Directory with training and validation data')\n",
"parser.add_argument('--learning-rate', type=float, default=0.0005, help='Learning rate for training')\n",
"parser.add_argument('--per-device-train-batch-size', type=int, default=384, help='Per device batch size for training')\n",
Expand Down

0 comments on commit 61f67ab

Please sign in to comment.