diff --git a/deepAccNet/model.py b/deepAccNet/model.py
index a752adb..d8a4451 100644
--- a/deepAccNet/model.py
+++ b/deepAccNet/model.py
@@ -124,8 +124,8 @@ def calculate_LDDT(estogram, mask, center=7):
     nres = mask.shape[-1]
     diags = torch.ones((nres, nres)).to(device) - torch.eye(nres).to(device)
     for i in range(1,5):
-        diags = diags - torch.diag(torch.ones(nres-i), diagonal=i)
-        diags = diags - torch.diag(torch.ones(nres-i), diagonal=-1*i)
+        diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=i).to(device)
+        diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=-1*i).to(device)
     mask = torch.mul(mask, diags)
     masked = torch.mul(estogram, mask)
 
@@ -167,4 +167,4 @@ def scatter_nd(indices, updates, shape):
     out = out.scatter_add(0, flattened_indices, updates)
     
     # Reshape
-    return out.view(shape)
\ No newline at end of file
+    return out.view(shape)
diff --git a/deepAccNet/model2.py b/deepAccNet/model2.py
index 50e633c..3a8862e 100644
--- a/deepAccNet/model2.py
+++ b/deepAccNet/model2.py
@@ -293,8 +293,8 @@ def calculate_LDDT(estogram, mask, center=7):
     nres = mask.shape[-1]
     diags = torch.ones((nres, nres)).to(device) - torch.eye(nres).to(device)
     for i in range(1,5):
-        diags = diags - torch.diag(torch.ones(nres-i), diagonal=i)
-        diags = diags - torch.diag(torch.ones(nres-i), diagonal=-1*i)
+        diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=i).to(device)
+        diags = diags - torch.diag(torch.ones(nres-i).to(device), diagonal=-1*i).to(device)
     masked = torch.mul(estogram, mask)
 
     p0 = (masked[center]).sum(axis=0)
@@ -335,4 +335,4 @@ def scatter_nd(indices, updates, shape):
     out = out.scatter_add(0, flattened_indices, updates)
     
     # Reshape
-    return out.view(shape)
\ No newline at end of file
+    return out.view(shape)