Skip to content

Commit

Permalink
Fix MIND-SSC (#77)
Browse files Browse the repository at this point in the history
* Add squared lncc and mind-ssc losses

* fix cpu error and add indexing parameter for meshgrid in mind-ssc

---------

Co-authored-by: Basar Demir <[email protected]>
  • Loading branch information
basardemir and Basar Demir authored Oct 29, 2024
1 parent c926e9a commit 2a6e26e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/icon_registration/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,15 +771,15 @@ def compute_mindssc(self, img):
dist = self.pdist_squared(six_neighbourhood.t().unsqueeze(0)).squeeze(0)

# define comparison mask
x, y = torch.meshgrid(torch.arange(6), torch.arange(6))
x, y = torch.meshgrid(torch.arange(6), torch.arange(6), indexing='ij')
mask = ((x > y).view(-1) & (dist == 2).view(-1))

# build kernel
idx_shift1 = six_neighbourhood.unsqueeze(1).repeat(1, 6, 1).view(-1, 3)[mask, :]
idx_shift2 = six_neighbourhood.unsqueeze(0).repeat(6, 1, 1).view(-1, 3)[mask, :]
mshift1 = torch.zeros(12, 1, 3, 3, 3).cuda()
mshift1 = torch.zeros(12, 1, 3, 3, 3).to(img.device)
mshift1.view(-1)[torch.arange(12) * 27 + idx_shift1[:, 0] * 9 + idx_shift1[:, 1] * 3 + idx_shift1[:, 2]] = 1
mshift2 = torch.zeros(12, 1, 3, 3, 3).cuda()
mshift2 = torch.zeros(12, 1, 3, 3, 3).to(img.device)
mshift2.view(-1)[torch.arange(12) * 27 + idx_shift2[:, 0] * 9 + idx_shift2[:, 1] * 3 + idx_shift2[:, 2]] = 1
rpad1 = torch.nn.ReplicationPad3d(self.dilation)
rpad2 = torch.nn.ReplicationPad3d(self.radius)
Expand Down

0 comments on commit 2a6e26e

Please sign in to comment.