From 38644851f53dda73caa77421e14503aaccac4dbc Mon Sep 17 00:00:00 2001
From: Hugo Sadok
Date: Tue, 21 May 2024 11:12:48 -0400
Subject: [PATCH] Add support for MPS backend (Apple Metal)

---
 main.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index fec4778..469619f 100644
--- a/main.py
+++ b/main.py
@@ -27,8 +27,8 @@ def train(args, train_loader, model, criterion, optimizer, writer):
     loss_epoch = 0
     for step, ((x_i, x_j), _) in enumerate(train_loader):
         optimizer.zero_grad()
-        x_i = x_i.cuda(non_blocking=True)
-        x_j = x_j.cuda(non_blocking=True)
+        x_i = x_i.to(args.device, non_blocking=True)
+        x_j = x_j.to(args.device, non_blocking=True)
 
         # positive pair, with encoding
         h_i, h_j, z_i, z_j = model(x_i, x_j)
@@ -170,8 +170,16 @@ def main(gpu, args):
     if not os.path.exists(args.model_path):
         os.makedirs(args.model_path)
 
-    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    args.num_gpus = torch.cuda.device_count()
+    if torch.cuda.is_available():
+        args.device = torch.device("cuda")
+        args.num_gpus = torch.cuda.device_count()
+    elif torch.backends.mps.is_available():
+        args.device = torch.device("mps")
+        args.num_gpus = 1
+    else:
+        args.device = torch.device("cpu")
+        args.num_gpus = 0
+
     args.world_size = args.gpus * args.nodes
 
     if args.nodes > 1: