Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New version of dendrogram #295

Merged
merged 3 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Cluster Algorithms
cluster
louvain
leiden
calc_dendrogram
split_one_cluster
spectral_louvain
spectral_leiden
Expand Down Expand Up @@ -161,7 +162,7 @@ Plotting
violin
heatmap
dotplot
dendrogram
plot_dendrogram
hvfplot
qcviolin
volcano
Expand Down
8 changes: 8 additions & 0 deletions docs/references.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
References
----------

.. [Bass13] J. I. F. Bass, A. Diallo, J. Nelson, J. M. Soto, C. L. Myers, and A. J. Walhout,
"Using networks to measure similarity between genes: association index selection",
In `Nature methods <https://www.nature.com/articles/nmeth.2728>`_, 2013.

.. [Belkina19] A. C. Belkina, C. O. Ciccolella, R. Anno, R. Halpert, J. Spidlen, and J. E. Snyder-Cappione,
"Automated optimized parameters for T-distributed stochastic neighbor embedding improve visualization and analysis of large datasets",
In `Nature Communications <https://www.nature.com/articles/s41467-019-13055-y>`__, 2019.
Expand Down Expand Up @@ -61,6 +65,10 @@ References
"UMAP: Uniform Manifold Approximation and Projection for Dimension Reduction",
Preprint at `arXiv <https://arxiv.org/abs/1802.03426>`__, 2018.

.. [Suo18] S. Suo, Q. Zhu, A. Saadatpour, L. Fei, G. Guo, and G. Yuan,
"Revealing the critical regulators of cell identity in the mouse cell atlas",
In `Cell reports <https://www.sciencedirect.com/science/article/pii/S2211124718316346>`_, 2018.

.. [Traag19] V. A. Traag, L. Waltman, and N. J. van Eck,
"From Louvain to Leiden: guaranteeing well-connected communities",
In `Scientific Reports <https://www.nature.com/articles/s41598-019-41695-z>`__, 2019.
Expand Down
3 changes: 2 additions & 1 deletion pegasus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
leiden,
spectral_louvain,
spectral_leiden,
calc_dendrogram,
tsne,
umap,
fle,
Expand Down Expand Up @@ -92,7 +93,7 @@
violin,
heatmap,
dotplot,
dendrogram,
plot_dendrogram,
hvfplot,
qcviolin,
volcano,
Expand Down
2 changes: 1 addition & 1 deletion pegasus/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
violin,
heatmap,
dotplot,
dendrogram,
plot_dendrogram,
hvfplot,
qcviolin,
volcano,
Expand Down
75 changes: 73 additions & 2 deletions pegasus/plotting/plot_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,78 @@ def non_zero(g):
return fig if return_fig else None


def dendrogram(
def plot_dendrogram(
data: Union[MultimodalData, UnimodalData, anndata.AnnData],
graph_key: str = "dendrogram",
panel_size: Tuple[float, float] = (10, 6),
label_rotation: float = 45,
label_fontsize: int = 10,
orientation: str = 'top',
color_threshold: Optional[float] = None,
return_fig: bool = False,
dpi: float = 300.0,
**kwargs,
) -> Union[plt.Figure, None]:
"""
Generate a dendrogram on hierarchical clustering result

The metric in use is a Connection Specific Index (CSI) matrix ([Suo18]_, [Bass13]_) built from the correlations between ``groupby`` attribute levels regarding the ``rep`` embedding.

Parameters
----------

data: ``MultimodalData``, ``UnimodalData``, or ``AnnData`` object
Single cell expression data.
graph_key: ``str``, optional, ``"dendrogram"``
Keyword in ``data.uns`` that stores the dendrogram configurations calculated by ``pegasus.calc_dendrogram`` function.
panel_size: ``Tuple[float, float]``, optional, default: ``(10, 6)``
The size (width, height) in inches of figure.
label_rotation: ``float``, optional, default: ``45``
The rotation angle of labels.
label_fontsize: ``int``, optional, default: ``10``
The font size of labels.
orientation: ``str``, optional, default: ``top``
The direction to plot the dendrogram. Available options are: ``top``, ``bottom``, ``left``, ``right``. See `scipy dendrogram documentation`_ for explanation.
color_threshold``float``, optional, default: ``None``
Threshold for coloring clusters. See `scipy dendrogram documentation`_ for explanation.
return_fig: ``bool``, optional, default: ``False``
Return a ``Figure`` object if ``True``; return ``None`` otherwise.
dpi``float``, optional, default: ``300.0``
The resolution in dots per inch.

Returns
-------

``Figure`` object
A ``matplotlib.figure.Figure`` object containing the dot plot if ``return_fig == True``

Examples
--------
>>> pg.plot_dendrogram(data)
>>> pg.plot_dendrogram(data, graph_key="custom_dendrogram", label_rotation=90)

.. _scipy dendrogram documentation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.dendrogram.html
"""
assert graph_key in data.uns, f"Key {graph_key} not in data.uns! Either a wrong key name, or you haven't run calc_dendrogram function first."

from scipy.cluster.hierarchy import dendrogram

Z, labels = data.uns[graph_key]
fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)
dendrogram(
Z,
ax=ax,
labels=labels,
color_threshold=color_threshold,
orientation=orientation,
)
fig.tight_layout()
plt.xticks(rotation=label_rotation, fontsize=label_fontsize)

return fig if return_fig else None


def _dendrogram_obsolete(
yihming marked this conversation as resolved.
Show resolved Hide resolved
data: Union[MultimodalData, UnimodalData, anndata.AnnData],
groupby: str,
rep: str = 'pca',
Expand Down Expand Up @@ -1684,7 +1755,7 @@ def dendrogram(

clusterer = AgglomerativeClustering(
n_clusters=n_clusters,
affinity=affinity,
metric=affinity,
linkage=linkage,
compute_full_tree=compute_full_tree,
distance_threshold=distance_threshold
Expand Down
19 changes: 17 additions & 2 deletions pegasus/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,16 @@
from .graph_operations import construct_graph
from .diffusion_map import diffmap
from .pseudotime import calc_pseudotime, infer_path
from .clustering import jump_method, louvain, leiden, spectral_louvain, spectral_leiden, cluster, split_one_cluster
from .clustering import (
jump_method,
louvain,
leiden,
spectral_louvain,
spectral_leiden,
cluster,
split_one_cluster,
calc_dendrogram,
)
from .net_regressor import net_train_and_predict
from .visualization import (
tsne,
Expand All @@ -57,7 +66,13 @@
net_umap,
net_fle,
)
from .diff_expr import de_analysis, markers, write_results_to_excel, cluster_specific_markers, run_de_analysis
from .diff_expr import (
de_analysis,
markers,
write_results_to_excel,
cluster_specific_markers,
run_de_analysis,
)
from .gradient_boosting import find_markers, run_find_markers
from .subcluster_utils import clone_subset
from .signature_score import calc_signature_score, calculate_z_score
Expand Down
105 changes: 103 additions & 2 deletions pegasus/tools/clustering.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import time
import numpy as np
import pandas as pd
from anndata import AnnData
from pandas.api.types import is_categorical_dtype
from pegasusio import MultimodalData
from pegasusio import MultimodalData, UnimodalData
from natsort import natsorted

from threadpoolctl import threadpool_limits
from scipy.sparse import issparse
from sklearn.cluster import KMeans
from typing import List, Optional, Union

from pegasus.tools import eff_n_jobs, construct_graph, calc_stat_per_batch
from pegasus.tools import eff_n_jobs, construct_graph, calc_stat_per_batch, X_from_rep, slicing

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -722,3 +724,102 @@ def split_one_cluster(
data.obs[res_label] = pd.Categorical(values = new_clust, categories = np.concatenate((cats[0:idx_cat], np.array(cats_sub), cats[idx_cat+1:])))
data.register_attr(res_label, "cluster")
del tmpdat


@timer(logger=logger)
def calc_dendrogram(
data: Union[MultimodalData, UnimodalData, AnnData],
groupby: str = "obs",
rep: Optional[str] = "pca",
genes: Optional[List[str]] = None,
on_average: bool = True,
linkage_method: str = "ward",
res_key: str = "dendrogram",
) -> None:
"""
Cluster data using hierarchical clustering algorithm.

The metric in use is a Connection Specific Index (CSI) matrix ([Suo18]_, [Bass13]_) built from the correlations between ``groupby`` attribute levels regarding the ``rep`` embedding.

Parameters
----------

data: ``MultimodalData``, ``UnimodalData``, or ``AnnData`` object
Single cell expression data.
groupby: ``str``, optional, default: ``None``
Set cluster labels in use.
If ``"obs"``, use cell names (i.e. ``data.obs_names``); if ``"var"``, use feature names (i.e. ``data.var_names``).
Otherwise, specify a categorical cell or feature attribute to use, which must exist in ``data.obs`` or ``data.var``.
rep: ``str``, optional, default: ``pca``
Cell embedding to use. If specified, it only works when ``genes`` is ``None``, and its key ``"X_"+rep`` must exist in ``data.obsm``. By default, use PCA embedding.
If ``None``, use the current count matrix ``data.X``.
genes: ``List[str]``, optional, default: ``None``
List of genes to use. Gene names must exist in ``data.var``. If set, use the counts in ``data.X`` for plotting; if ``None``, use the embedding specified in ``rep``.
on_average: ``bool``, optional, default: ``True``
If ``True``, clustering ``groupby`` levels based on their mean values. Only works when ``groupby`` is not ``None``.
linkage_method: ``str``, optional, default: ``ward``
Which linkage criterion to use, used by hierarchical clustering. Available options: ``ward`` (default), ``single``, ``complete``, ``average``, ``weighted``, ``centroid``, ``median``.
See `scipy linkage documentation`_ for details.

Returns
-------
``None``

Update ``data.uns``:
* ``data.uns[res_key]``: A tuple of the calculated linkage matrix and its corresponding labels.

Examples
--------
>>> pg.calc_dendrogram(data, groupby='leiden_labels')
>>> pg.calc_dendrogram(data, genes=['CD4', 'CD8A', 'CD8B'], on_average=False)
>>> pg.calc_dendrogram(data, groupby="var", rep=None, on_average=False)

.. _scipy linkage documentation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html
"""
# Set up embedding or count matrix to use
if genes is None:
if rep:
embed_df = pd.DataFrame(X_from_rep(data, rep))
else:
embed_df = pd.DataFrame(data.X.toarray() if issparse(data.X) else data.X)
else:
embed_df = pd.DataFrame(slicing(data[:, genes].X))

# Set up index
if groupby == "obs":
indices = data.obs_names
elif groupby == "var":
embed_df = embed_df.T
indices = data.var_names
elif groupby in data.obs:
indices = data.obs[groupby]
elif groupby in data.var:
embed_df = embed_df.T
indices = data.var[groupby]
else:
raise Exception(f"The groupby key {groupby} doesn't exist in data.obs or data.var!")
embed_df.set_index(indices, inplace=True)

# Use group mean if on_average is True
if on_average:
embed_df = embed_df.groupby(level=0, observed=True).mean()
if not isinstance(embed_df.index.dtype, pd.CategoricalDtype):
embed_df.index = embed_df.index.astype("category")

# Calculate Pearson's correlation between cluster labels
corr_df = pd.DataFrame(np.corrcoef(embed_df, rowvar=True), columns=embed_df.index, index=embed_df.index) # Faster than pandas corr
corr_mat = corr_df.values

from pegasus.tools.utils import calc_csi_matrix

# Calculate CSI matrix
csi_mat = calc_csi_matrix(corr_mat)
csi_df = pd.DataFrame(csi_mat, columns=corr_df.index, index=corr_df.index)

from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform

dissim_df = 1 - csi_df
Z = linkage(squareform(dissim_df), method=linkage_method, optimal_ordering=True)

data.uns[res_key] = (Z, dissim_df.index.values.astype(str))
16 changes: 16 additions & 0 deletions pegasus/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from numba import njit
import pandas as pd
from pandas.api.types import is_categorical_dtype
from scipy.sparse import issparse, csr_matrix, csc_matrix
Expand Down Expand Up @@ -253,3 +254,18 @@ def largest_variance_from_random_matrix(
res = (quantiles[pval] * sigma + mu) / (ncells - 1)

return res


@njit(fastmath=True, cache=True)
def calc_csi_matrix(corr_mat):
n = corr_mat.shape[0]
csi_mat = np.eye(n, n) * n
for i in range(n - 1):
for j in range(i + 1, n):
pcc = corr_mat[i, j]
csi_mat[i, j] = np.sum(
(corr_mat[i, :] < pcc - 0.05) & (corr_mat[j, :] < pcc - 0.05)
)
csi_mat[j, i] = csi_mat[i, j]
csi_mat /= n
return csi_mat
Loading