diff --git a/deepAccNet/model.py b/deepAccNet/model.py index a752adb..d8a4451 100644 --- a/deepAccNet/model.py +++ b/deepAccNet/model.py @@ -124,8 +124,8 @@ def calculate_LDDT(estogram, mask, center=7): nres = mask.shape[-1] diags = torch.ones((nres, nres)).to(device) - torch.eye(nres).to(device) for i in range(1,5): - diags = diags - torch.diag(torch.ones(nres-i), diagonal=i) - diags = diags - torch.diag(torch.ones(nres-i), diagonal=-1*i) + diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=i).to(device) + diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=-1*i).to(device) mask = torch.mul(mask, diags) masked = torch.mul(estogram, mask) @@ -167,4 +167,4 @@ def scatter_nd(indices, updates, shape): out = out.scatter_add(0, flattened_indices, updates) # Reshape - return out.view(shape) \ No newline at end of file + return out.view(shape) diff --git a/deepAccNet/model2.py b/deepAccNet/model2.py index 50e633c..3a8862e 100644 --- a/deepAccNet/model2.py +++ b/deepAccNet/model2.py @@ -293,8 +293,8 @@ def calculate_LDDT(estogram, mask, center=7): nres = mask.shape[-1] diags = torch.ones((nres, nres)).to(device) - torch.eye(nres).to(device) for i in range(1,5): - diags = diags - torch.diag(torch.ones(nres-i), diagonal=i) - diags = diags - torch.diag(torch.ones(nres-i), diagonal=-1*i) + diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=i).to(device) + diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=-1*i).to(device) masked = torch.mul(estogram, mask) p0 = (masked[center]).sum(axis=0) @@ -335,4 +335,4 @@ def scatter_nd(indices, updates, shape): out = out.scatter_add(0, flattened_indices, updates) # Reshape - return out.view(shape) \ No newline at end of file + return out.view(shape)