diff --git a/vae/README.md b/vae/README.md index cda6a33672..e2a432fd1e 100644 --- a/vae/README.md +++ b/vae/README.md @@ -14,8 +14,9 @@ The main.py script accepts the following arguments: optional arguments: --batch-size input batch size for training (default: 128) --epochs number of epochs to train (default: 10) - --no-cuda enables CUDA training - --mps enables GPU on macOS + --no-cuda disables CUDA training + --no-mps disables GPU on macOS + --no-xpu disables XPU training in Intel GPUs --seed random seed (default: 1) --log-interval how many batches to wait before logging training status -``` \ No newline at end of file +``` diff --git a/vae/main.py b/vae/main.py index d69833fbe0..f7915b9ced 100644 --- a/vae/main.py +++ b/vae/main.py @@ -13,10 +13,12 @@ help='input batch size for training (default: 128)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') -parser.add_argument('--no-cuda', action='store_true', default=False, +parser.add_argument('--no-cuda', action='store_true', help='disables CUDA training') -parser.add_argument('--no-mps', action='store_true', default=False, +parser.add_argument('--no-mps', action='store_true', help='disables macOS GPU training') +parser.add_argument('--no-xpu', action='store_true', + help='disables Intel XPU training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', @@ -24,6 +26,7 @@ args = parser.parse_args() args.cuda = not args.no_cuda and torch.cuda.is_available() use_mps = not args.no_mps and torch.backends.mps.is_available() +use_xpu = not args.no_xpu and torch.xpu.is_available() torch.manual_seed(args.seed) @@ -31,9 +34,13 @@ device = torch.device("cuda") elif use_mps: device = torch.device("mps") +elif use_xpu: + device = torch.device("xpu") else: device = torch.device("cpu") +print('Device to use: ', device) + kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True,