Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added function to find NMF programs (searching K) #298

Merged
merged 3 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pegasus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
tsvd_transform,
regress_out,
nmf,
find_nmf_programs,
integrative_nmf,
highly_variable_features,
run_harmony,
Expand Down
10 changes: 7 additions & 3 deletions pegasus/plotting/plot_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,8 @@ def plot_dendrogram(

from scipy.cluster.hierarchy import dendrogram

Z, labels = data.uns[graph_key]
Z = data.uns[graph_key][0]
labels = data.uns[graph_key][1].index
fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)
dendrogram(
Z,
Expand Down Expand Up @@ -2305,6 +2306,7 @@ def wordcloud(
data: Union[MultimodalData, UnimodalData, anndata.AnnData],
factor: int,
max_words: Optional[int] = 20,
features: Optional[str] = "highly_variable_features",
random_state: Optional[int] = 0,
colormap: Optional[str] = "hsv",
width: Optional[int] = 800,
Expand All @@ -2325,6 +2327,8 @@ def wordcloud(
Which factor to plot. factor starts from 0.
max_words: ``int``, optional, default: 20
Maximum number of genes to show in the image.
features: ``str``, optional, default: ``highly_variable_features``
Features selected for NMF computation.
random_state: ``int``, optional, default: 0
Random seed passing to WordCloud function.
colormap: ``str``, optional, default: ``hsv``
Expand All @@ -2351,9 +2355,9 @@ def wordcloud(
>>> fig = pg.wordcloud(data, factor=0)
"""
fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi) # default nrows = 1 & ncols = 1

assert 'W' in data.uns
hvg = data.var_names[data.var['highly_variable_features']]
hvg = data.var_names[data.var[features]]
word_dict = {}
for i in range(hvg.size):
word_dict[hvg[i]] = data.uns['W'][i, factor]
Expand Down
6 changes: 4 additions & 2 deletions pegasus/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,11 @@ def _get_palette(n_labels: int, with_background: bool = False, show_background:
palette = pegasus_20
elif n_labels <= 26:
palette = zeileis_26
else:
assert n_labels <= 64
elif n_labels <= 64:
palette = godsnot_64
else:
n_rep = (n_labels - 1) // 20 + 1 # a cyclic color panel
palette = pegasus_20 * n_rep

if with_background:
palette = np.array(
Expand Down
2 changes: 1 addition & 1 deletion pegasus/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from .subcluster_utils import clone_subset
from .signature_score import calc_signature_score, calculate_z_score
from .doublet_detection import infer_doublets, mark_doublets
from .nmf import nmf, integrative_nmf
from .nmf import nmf, integrative_nmf, find_nmf_programs
from .pseudobulk import pseudobulk, deseq2
from .fgsea import fgsea, write_fgsea_results_to_excel
from .scvitools import (
Expand Down
2 changes: 1 addition & 1 deletion pegasus/tools/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,4 +825,4 @@ def calc_dendrogram(
np.fill_diagonal(dissim_df.to_numpy(), 0) # Enforce main diagonal to be 0 to pass squareform requirement
Z = linkage(squareform(dissim_df), method=linkage_method, optimal_ordering=True)

data.uns[res_key] = (Z, dissim_df.index.values.astype(str))
data.uns[res_key] = (Z, csi_df)
3 changes: 2 additions & 1 deletion pegasus/tools/hvf_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def select_hvf_pegasus(
""" Select highly variable features using the pegasus method
"""
if "robust" not in data.var:
raise ValueError("Please run `identify_robust_genes` to identify robust genes")
logger.warning("Robust genes are not identified. Mark all genes as robust.")
data.var["robust"] = True

estimate_feature_statistics(data, batch)

Expand Down
164 changes: 163 additions & 1 deletion pegasus/tools/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numba import njit
from numba.typed import List as numbaList

from typing import List, Union
from typing import List, Union, Tuple
from pegasusio import UnimodalData, MultimodalData
from pegasus.tools import slicing, eff_n_jobs, calculate_nearest_neighbors, check_batch_key

Expand Down Expand Up @@ -205,6 +205,168 @@ def nmf(
data.obsm["X_nmf"] = H / np.linalg.norm(H, axis=0)


@timer(logger=logger)
def find_nmf_programs(
    data: Union[MultimodalData, UnimodalData],
    n_range: Tuple[int, int] = (4, 9),
    n_rep: int = 10,
    features: str = "highly_variable_features",
    space: str = "log",
    init: str = "random",
    algo: str = "halsvar",
    mode: str = "batch",
    tol: float = 1e-4,
    use_gpu: bool = False,
    alpha_W: float = 0.0,
    l1_ratio_W: float = 0.0,
    alpha_H: float = 0.01,
    l1_ratio_H: float = 1.0,
    fp_precision: str = "float",
    online_chunk_size: int = 5000,
    n_jobs: int = -1,
    random_state: int = 0,
) -> Tuple[list, list, list, list]:
    """Search over the number of NMF components (k) to help choose the number of programs.

    For every k in ``n_range`` (both ends inclusive), NMF is run ``n_rep`` times with different
    random seeds on the selected and scaled features. For each k, the factorization with the
    smallest error is kept. In addition, every run assigns each row of the input matrix to the
    component with the largest coefficient in ``H``; the resulting connectivity matrices are
    averaged into a consensus matrix, and the cophenetic correlation coefficient of an
    average-linkage hierarchical clustering on ``1 - consensus`` is reported for that k.

    The calculation uses `nmf-torch <https://github.com/lilab-bcb/nmf-torch>`_ package.

    Parameters
    ----------
    data: ``pegasusio.MultimodalData``
        Annotated data matrix with rows for cells and columns for genes.

    n_range: ``Tuple[int, int]``, optional, default: ``(4, 9)``.
        Smallest and largest number of components to try; both ends are included.

    n_rep: ``int``, optional, default: 10
        Number of reruns for each value in n_range. Must be at least 1.

    features: ``str``, optional, default: ``"highly_variable_features"``.
        Keyword in ``data.var`` to specify features used for nmf.

    space: ``str``, optional, default: ``log``.
        Choose from ``log`` and ``expression``. ``log`` works on log-transformed expression space; ``expression`` works on the original expression space (normalized by total UMIs).

    init: ``str``, optional, default: ``random``.
        Method to initialize NMF. Options are 'random', 'nndsvd', 'nndsvda' and 'nndsvdar'.

    algo: ``str``, optional, default: ``halsvar``
        Choose from ``mu`` (Multiplicative Update), ``hals`` (Hierarchical Alternative Least Square), ``halsvar`` (HALS variant, use HALS to mimic ``bpp`` and can get better convergence for sometimes) and ``bpp`` (alternative non-negative least squares with Block Principal Pivoting method).

    mode: ``str``, optional, default: ``batch``
        Learning mode. Choose from ``batch`` and ``online``. Notice that ``online`` only works when ``beta=2.0``. For other beta loss, it switches back to ``batch`` method.

    tol: ``float``, optional, default: ``1e-4``
        The toleration used for convergence check.

    use_gpu: ``bool``, optional, default: ``False``
        If ``True``, use GPU if available. Otherwise, use CPU only.

    alpha_W: ``float``, optional, default: ``0.0``
        A numeric scale factor which multiplies the regularization terms related to W.
        If zero or negative, no regularization regarding W is considered.

    l1_ratio_W: ``float``, optional, default: ``0.0``
        The ratio of L1 penalty on W, must be between 0 and 1. And thus the ratio of L2 penalty on W is (1 - l1_ratio_W).

    alpha_H: ``float``, optional, default: ``0.01``
        A numeric scale factor which multiplies the regularization terms related to H.
        If zero or negative, no regularization regarding H is considered.

    l1_ratio_H: ``float``, optional, default: ``1.0``
        The ratio of L1 penalty on H, must be between 0 and 1. And thus the ratio of L2 penalty on H is (1 - l1_ratio_H).

    fp_precision: ``str``, optional, default: ``float``
        The numeric precision on the results. Choose from ``float`` and ``double``.

    online_chunk_size: ``int``, optional, default: ``5000``
        The chunk / mini-batch size for online learning. Only works when ``mode='online'``.

    n_jobs : `int`, optional (default: -1)
        Number of threads to use. -1 refers to using all physical CPU cores.

    random_state: ``int``, optional, default: ``0``.
        Random seed to be set for reproducing result.

    Returns
    -------
    Hs: ``list``
        Best (lowest error) H for each k in n_range.
    Ws: ``list``
        Best (lowest error) W for each k in n_range.
    errs: ``list``
        Lowest error for each k in n_range.
    coph_corrs: ``list``
        Cophenetic correlation coefficient of the consensus clustering for each k in n_range.

    Raises
    ------
    ValueError
        If ``n_rep`` is smaller than 1.

    Examples
    --------
    >>> Hs, Ws, errs, coph_corrs = pg.find_nmf_programs(data)
    """
    if n_rep < 1:
        # With no runs there is neither a best factorization nor a consensus matrix.
        raise ValueError(f"n_rep must be at least 1, got {n_rep}.")

    X = _select_and_scale_features(data, features=features, space=space)

    try:
        from nmf import run_nmf
    except ImportError as e:
        import sys
        logger.error(f"{e}\nNeed NMF-Torch! Try 'pip install nmf-torch'.")
        sys.exit(-1)

    # Kept outside the try block so that a scipy problem is not reported as a missing nmf-torch.
    from scipy.cluster.hierarchy import linkage, cophenet
    from scipy.spatial.distance import squareform

    Hs = []
    Ws = []
    errs = []
    coph_corrs = []

    rng = np.random.default_rng(random_state)
    BIG_NUM = 1000000000
    n_threads = eff_n_jobs(n_jobs)  # loop-invariant, resolve once
    n_obs = X.shape[0]
    k_min, k_max = n_range[0], n_range[1]

    # One n_obs x n_obs accumulator reused for every k. Summing the 0/1 connectivity matrices
    # and dividing by n_rep gives the same consensus matrix as stacking all runs and taking
    # the mean, while needing n_rep times less memory.
    consensus = np.zeros((n_obs, n_obs))

    for k in range(k_min, k_max + 1):
        logger.info(f"Begin k={k}:")

        H_best = W_best = None
        err_best = np.inf
        consensus.fill(0.0)

        for _ in range(n_rep):
            H, W, err = run_nmf(
                X,
                n_components=k,
                init=init,
                algo=algo,
                mode=mode,
                tol=tol,
                n_jobs=n_threads,
                random_state=rng.integers(BIG_NUM),
                use_gpu=use_gpu,
                alpha_W=alpha_W,
                l1_ratio_W=l1_ratio_W,
                alpha_H=alpha_H,
                l1_ratio_H=l1_ratio_H,
                fp_precision=fp_precision,
                online_chunk_size=online_chunk_size,
            )

            if err < err_best:
                err_best = err
                H_best = H
                W_best = W

            # Assign every row to its dominant component; two rows are "connected" in this
            # run when they share the same dominant component.
            clusters = H.argmax(axis=1)
            consensus += clusters.reshape((-1, 1)) == clusters.reshape((1, -1))

        consensus /= n_rep
        # Turn the consensus (similarity) into a dissimilarity in place. The diagonal is
        # exactly 1 before this step, hence exactly 0 afterwards, as squareform requires.
        np.subtract(1.0, consensus, out=consensus)
        Y = squareform(consensus)
        Z = linkage(Y, method='average')
        coph_corr = cophenet(Z, Y)[0]

        Hs.append(H_best)
        Ws.append(W_best)
        errs.append(err_best)
        coph_corrs.append(coph_corr)

    return Hs, Ws, errs, coph_corrs


@njit(fastmath=True, cache=True)
def _refine_cluster(clusters, indices, ncluster):
Expand Down
Loading