Skip to content

Commit

Permalink
Merge pull request #21 from epaaso/master
Browse files Browse the repository at this point in the history
Sparse dot product
  • Loading branch information
melonheader authored Nov 10, 2024
2 parents 0a4baf4 + fe2775e commit 07ffba7
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions ikarus/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def calculate_connectivities(
path = Path.cwd() / out_dir / name
path.mkdir(parents=True, exist_ok=True)
save_npz(path / "connectivities_sparse.npz", sparse)
return sparse.todense()
return sparse


def propagate_labels(
Expand Down Expand Up @@ -113,8 +113,8 @@ def propagate_labels(
] = True
final_pred_proba.loc[certainty_info[f"certain{i}"] == False] = 0.000001

final_step_mtx = np.dot(connectivities, final_pred_proba.values)
final_step_mtx = np.divide(final_step_mtx, final_step_mtx.sum(axis=1))
final_step_mtx = connectivities.dot(final_pred_proba.values)
final_step_mtx = np.divide(final_step_mtx, final_step_mtx.sum(axis=1, keepdims=True))
final_pred_proba.loc[:, :] = final_step_mtx

current = final_pred_proba.idxmax(axis=1)
Expand Down Expand Up @@ -393,7 +393,7 @@ def predict(
self.out_dir,
)
else:
connectivities = load_npz(connectivities_path).todense()
connectivities = load_npz(connectivities_path)
if connectivities.shape[0] != adata.shape[0]:
raise IndexError(
f"Shape of connectivities matrix ({connectivities.shape}) does "
Expand Down Expand Up @@ -550,7 +550,7 @@ def cnv_correct(
self.out_dir,
)
else:
connectivities = load_npz(connectivities_path).todense()
connectivities = load_npz(connectivities_path)
if connectivities.shape[0] != adata.shape[0]:
raise IndexError(
f"Shape of connectivities matrix ({connectivities.shape}) does "
Expand Down

0 comments on commit 07ffba7

Please sign in to comment.