diff --git a/src/icon_registration/losses.py b/src/icon_registration/losses.py index f6f5cdb..4327612 100644 --- a/src/icon_registration/losses.py +++ b/src/icon_registration/losses.py @@ -771,15 +771,15 @@ def compute_mindssc(self, img): dist = self.pdist_squared(six_neighbourhood.t().unsqueeze(0)).squeeze(0) # define comparison mask - x, y = torch.meshgrid(torch.arange(6), torch.arange(6)) + x, y = torch.meshgrid(torch.arange(6), torch.arange(6), indexing='ij') mask = ((x > y).view(-1) & (dist == 2).view(-1)) # build kernel idx_shift1 = six_neighbourhood.unsqueeze(1).repeat(1, 6, 1).view(-1, 3)[mask, :] idx_shift2 = six_neighbourhood.unsqueeze(0).repeat(6, 1, 1).view(-1, 3)[mask, :] - mshift1 = torch.zeros(12, 1, 3, 3, 3).cuda() + mshift1 = torch.zeros(12, 1, 3, 3, 3).to(img.device) mshift1.view(-1)[torch.arange(12) * 27 + idx_shift1[:, 0] * 9 + idx_shift1[:, 1] * 3 + idx_shift1[:, 2]] = 1 - mshift2 = torch.zeros(12, 1, 3, 3, 3).cuda() + mshift2 = torch.zeros(12, 1, 3, 3, 3).to(img.device) mshift2.view(-1)[torch.arange(12) * 27 + idx_shift2[:, 0] * 9 + idx_shift2[:, 1] * 3 + idx_shift2[:, 2]] = 1 rpad1 = torch.nn.ReplicationPad3d(self.dilation) rpad2 = torch.nn.ReplicationPad3d(self.radius)