diff --git a/Spectra/Spectra_gpu.py b/Spectra/Spectra_gpu.py index 8ba8665..cd48ad4 100644 --- a/Spectra/Spectra_gpu.py +++ b/Spectra/Spectra_gpu.py @@ -259,9 +259,9 @@ def __init__(self, X, labels, adj_matrix, L, weights = None, lam = 10e-4, delta= lst_weights.append(torch.Tensor(weights[cell_type]) - torch.Tensor(np.diag(np.diag(weights[cell_type]))) ) else: lst_weights.append(torch.zeros((self.p, self.p))) + self.weights = torch.stack(lst_weights).to(device) else: - self.weights = self.adj_matrix - self.weights = torch.stack(lst_weights).to(device) + self.weights = self.adj_matrix self.ct_order = ct_order self.L_tot = L_tot self.L_list = L_list