Fix deseq2 #303

Merged · 2 commits · Jun 5, 2024
2 changes: 1 addition & 1 deletion pegasus/pseudo/__init__.py
@@ -1 +1 @@
from .convenient import markers, write_results_to_excel, volcano
from .convenient import markers, write_results_to_excel, volcano, get_original_DE_result
52 changes: 38 additions & 14 deletions pegasus/pseudo/convenient.py
@@ -5,13 +5,13 @@
import logging
logger = logging.getLogger(__name__)

from pegasusio import UnimodalData, timer
from pegasusio import MultimodalData, timer
from typing import Union, Dict, Optional, Tuple



def markers(
pseudobulk: UnimodalData,
pseudobulk: MultimodalData,
head: int = None,
de_key: str = "deseq2",
alpha: float = 0.05,
@@ -23,11 +23,11 @@ def markers(

Parameters
----------
pseudobulk: ``UnimodalData``
pseudobulk: ``MultimodalData``
Pseudobulk data matrix with rows for pseudo-bulk samples and columns for genes.

head: ``int``, optional, default: ``None``
List only top ``head`` genes. If ``None``, show any DE genes.
List only top ``head`` genes. If ``None``, show all DE genes.

de_key: ``str``, optional, default: ``deseq2``
Keyword of DE result stored in ``data.varm``.
@@ -39,8 +39,8 @@
-------
results: ``Dict[str, pd.DataFrame]``
A Python dictionary containing DE results. This dictionary contains two keywords: 'up' and 'down'.
'up' refers to up-regulated genes, which should have 'log2FoldChange' > 0.5.
'down' refers to down-regulated genes, which should have 'log2FoldChange' < 0.5.
'up' refers to up-regulated genes, which have 'log2FoldChange' > 0. The genes are ranked by Wald test statistics.
'down' refers to down-regulated genes, which have 'log2FoldChange' < 0. The genes are ranked by Wald test statistics.

Examples
--------
@@ -53,11 +53,11 @@
df = pd.DataFrame(data=pseudobulk.varm[de_key], index=pseudobulk.var_names)
idx = df["padj"] <= alpha

idx_up = idx & (df["log2FoldChange"].values > 0.0)
df_up = df.loc[idx_up].sort_values(by="log2FoldChange", ascending=False, inplace=False)
idx_up = idx & (df["stat"].values > 0.0)
df_up = df.loc[idx_up].sort_values(by="stat", ascending=False, inplace=False)
res_dict["up"] = pd.DataFrame(df_up if head is None else df_up.iloc[0:head])
idx_down = idx & (df["log2FoldChange"].values < 0.0)
df_down = df.loc[idx_down].sort_values(by="log2FoldChange", ascending=True, inplace=False)
idx_down = idx & (df["stat"].values < 0.0)
df_down = df.loc[idx_down].sort_values(by="stat", ascending=True, inplace=False)
res_dict["down"] = pd.DataFrame(df_down if head is None else df_down.iloc[0:head])

return res_dict
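
For reference, a minimal usage sketch of markers() as defined here; it assumes `pseudobulk` is a MultimodalData object that already carries DESeq2 results under varm["deseq2"]:

from pegasus import pseudo

# `pseudobulk` holds pseudo-bulk counts plus DESeq2 results in varm["deseq2"]
res = pseudo.markers(pseudobulk, head=10, de_key="deseq2", alpha=0.05)
res["up"]    # up-regulated genes (padj <= alpha, stat > 0), ranked by Wald statistic, descending
res["down"]  # down-regulated genes (padj <= alpha, stat < 0), ranked by Wald statistic, ascending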
@@ -149,10 +149,11 @@ def add_worksheet(


def volcano(
pseudobulk: UnimodalData,
pseudobulk: MultimodalData,
de_key: str = "deseq2",
qval_threshold: float = 0.05,
log2fc_threshold: float = 1.0,
rank_by: str = "log2fc",
top_n: int = 20,
panel_size: Optional[Tuple[float, float]] = (6, 4),
return_fig: Optional[bool] = False,
@@ -164,14 +165,16 @@
Parameters
-----------

pseudobulk: ``UnimodalData`` object.
pseudobulk: ``MultimodalData`` object.
Pseudobulk data matrix.
de_key: ``str``, optional, default: ``deseq2``
The varm keyword for DE results. data.varm[de_key] should store the full DE result table.
qval_threshold: ``float``, optional, default: 0.05.
Selected FDR rate. A horizontal line indicating this rate will be shown in the figure.
log2fc_threshold: ``float``, optional, default: 1.0
Log2 fold change threshold to highlight biologically interesting genes. Two vertical lines representing negative and positive log2 fold change will be shown.
rank_by: ``str``, optional, default: ``log2fc``
Rank genes by the ``rank_by`` metric in the DE results. By default, rank by log2 fold change (``"log2fc"``); set to ``"neglog10p"`` to rank genes by p-values.
top_n: ``int``, optional, default: ``20``
Number of top DE genes to label with gene names. Genes are ranked by the ``rank_by`` metric.
panel_size: ``Tuple[float, float]``, optional, default: ``(6, 4)``
@@ -258,13 +261,27 @@
texts = []

idx = np.where(idxsig & (log2fc >= log2fc_threshold))[0]
posvec = np.argsort(log2fc[idx])[::-1][0:top_n]
if rank_by == "log2fc":
posvec = np.argsort(log2fc[idx])[::-1][0:top_n]
elif rank_by == "neglog10p":
posvec = np.argsort(neglog10p[idx])[::-1][0:top_n]
else:
import sys
logger.error(f"Invalid rank_by key! Must choose from ['log2fc', 'neglog10p']!")
sys.exit(-1)
for pos in posvec:
gid = idx[pos]
texts.append(ax.text(log2fc[gid], neglog10p[gid], gene_names[gid], fontsize=5))

idx = np.where(idxsig & (log2fc <= -log2fc_threshold))[0]
posvec = np.argsort(log2fc[idx])[0:top_n]
if rank_by == "log2fc":
posvec = np.argsort(log2fc[idx])[0:top_n]
elif rank_by == "neglog10p":
posvec = np.argsort(neglog10p[idx])[::-1][0:top_n]
else:
import sys
logger.error(f"Invalid rank_by key! Must choose from ['log2fc', 'neglog10p']!")
sys.exit(-1)
for pos in posvec:
gid = idx[pos]
texts.append(ax.text(log2fc[gid], neglog10p[gid], gene_names[gid], fontsize=5))
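
A sketch of how the rank_by option might be used when plotting; the parameter values are illustrative:

from pegasus import pseudo

# label the 20 most significant genes instead of those with the largest fold change
fig = pseudo.volcano(
    pseudobulk,
    de_key="deseq2",
    qval_threshold=0.05,
    log2fc_threshold=1.0,
    rank_by="neglog10p",
    top_n=20,
    return_fig=True,
)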
@@ -273,3 +290,10 @@
adjust_text(texts, arrowprops=dict(arrowstyle='-', color='k', lw=0.5))

return fig if return_fig else None


def get_original_DE_result(
data: MultimodalData,
de_key: str = "deseq2",
) -> pd.DataFrame:
return pd.DataFrame(data.varm[de_key], index=data.var_names)
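
The new helper simply wraps the stored DE table in a DataFrame; a usage sketch (the "padj" column name follows DESeq2 output conventions):

from pegasus import pseudo

df = pseudo.get_original_DE_result(pseudobulk, de_key="deseq2")
df.sort_values("padj").head()  # full DESeq2 result table, indexed by gene names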
70 changes: 48 additions & 22 deletions pegasus/tools/pseudobulk.py
@@ -33,7 +33,7 @@ def pseudobulk(
data: Union[MultimodalData, UnimodalData],
groupby: str,
attrs: Optional[Union[List[str], str]] = None,
mat_key: str = "counts",
mat_key: Optional[str] = None,
condition: Optional[str] = None,
) -> MultimodalData:
"""Generate Pseudo-bulk count matrices.
@@ -53,9 +53,9 @@
Notice that for a categorical attribute, each pseudo-bulk's value is the one of highest frequency among its cells,
and for a numeric attribute, each pseudo-bulk's value is the mean among its cells.

mat_key: ``str``, optional, default: ``counts``
mat_key: ``str``, optional, default: ``None``
Specify the single-cell count matrix used for aggregating pseudo-bulk counts:
If specified, use the count matrix with key ``mat_key`` from matrices of ``data``; otherwise, default is ``counts``.
If specified, use the count matrix with key ``mat_key`` from the matrices of ``data``; otherwise, look for the ``counts`` matrix first, then fall back to ``raw.X`` if ``counts`` does not exist.

condition: ``str``, optional, default: ``None``
If set, additionally generate pseudo-bulk matrices per condition specified in ``data.obs[condition]``.
@@ -74,9 +74,21 @@
--------
>>> pg.pseudobulk(data, groupby="Channel")
"""
if mat_key is None:
if "counts" in data._unidata.matrices:
mat_key = "counts"
elif "raw.X" in data._unidata.matrices:
mat_key = "raw.X"
else:
import sys
logger.error("No matrix with default key found in data! Please specify an explicit matrix key!")
sys.exit(-1)
X = data.get_matrix(mat_key)

assert groupby in data.obs.columns, f"Sample key '{groupby}' must exist in data.obs!"
if groupby not in data.obs.columns:
import sys
logger.error(f"Sample key '{groupby}' must exist in data.obs!")
sys.exit(-1)

sample_vec = (
data.obs[groupby]
@@ -96,9 +108,10 @@
if isinstance(attrs, str):
attrs = [attrs]
for attr in attrs:
assert (
attr in data.obs.columns
), f"Cell attribute key '{attr}' must exist in data.obs!"
if attr not in data.obs.columns:
import sys
logger.error(f"Cell attribute key '{attr}' must exist in data.obs!")
sys.exit(-1)

for bulk in bulk_list:
df_bulk = df_barcode.loc[df_barcode[groupby] == bulk]
@@ -124,9 +137,10 @@
df_feature["featureid"] = data.var["featureid"]

if condition is not None:
assert (
condition in data.obs.columns
), f"Condition key '{attr}' must exist in data.obs!"
if condition not in data.obs.columns:
import sys
logger.error(f"Condition key '{attr}' must exist in data.obs!")
sys.exit(-1)

cluster_list = data.obs[condition].astype("category").cat.categories
for cls in cluster_list:
@@ -138,7 +152,7 @@
barcode_metadata=df_pseudobulk,
feature_metadata=df_feature,
matrices=mat_dict,
genome=groupby,
genome=data.get_genome(),
modality="pseudobulk",
cur_matrix="counts",
)
@@ -148,13 +162,14 @@

@timer(logger=logger)
def deseq2(
pseudobulk: Union[MultimodalData, UnimodalData],
pseudobulk: MultimodalData,
design: str,
contrast: Tuple[str, str, str],
backend: str = "pydeseq2",
de_key: str = "deseq2",
alpha: float = 0.05,
compute_all: bool = False,
verbose: bool = True,
n_jobs: int = -1,
) -> None:
"""Perform Differential Expression (DE) Analysis using DESeq2 on pseduobulk data.
@@ -185,6 +200,9 @@ def deseq2(
compute_all: ``bool``, optional, default: ``False``
Whether to perform DE analysis on all count matrices. By default (``compute_all=False``), only apply DE analysis to the default count matrix ``counts``.

verbose: ``bool``, optional, default: ``True``
Whether to show DESeq2 status updates during fitting. Only works when ``backend="pydeseq2"``.

n_jobs: ``int``, optional, default: ``-1``
Number of threads to use. If ``-1``, use all physical CPU cores. This only works when ``backend="pydeseq2"``.
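
A sketch of a typical call with the new verbose flag, assuming the function is exposed as pg.deseq2; the design factor and contrast levels are hypothetical:

import pegasus as pg

pg.deseq2(
    pseudo,                                        # pseudo-bulk data from pg.pseudobulk
    design="Condition",                            # obs column used as design factor (pydeseq2 backend)
    contrast=("Condition", "treated", "control"),  # hypothetical factor levels to compare
    backend="pydeseq2",
    verbose=False,                                 # silence pydeseq2 progress messages
)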

@@ -204,19 +222,20 @@
mat_keys = ['counts'] if not compute_all else pseudobulk.list_keys()
for mat_key in mat_keys:
if backend == "pydeseq2":
_run_pydeseq2(pseudobulk=pseudobulk, mat_key=mat_key, design_factors=design, contrast=contrast, de_key=de_key, alpha=alpha, n_jobs=n_jobs)
_run_pydeseq2(pseudobulk=pseudobulk, mat_key=mat_key, design_factors=design, contrast=contrast, de_key=de_key, alpha=alpha, n_jobs=n_jobs, verbose=verbose)
else:
_run_rdeseq2(pseudobulk=pseudobulk, mat_key=mat_key, design=design, contrast=contrast, de_key=de_key, alpha=alpha)


def _run_pydeseq2(
pseudobulk: Union[MultimodalData, UnimodalData],
pseudobulk: MultimodalData,
mat_key: str,
design_factors: Union[str, List[str]],
contrast: Tuple[str, str, str],
de_key: str,
alpha: float,
n_jobs: int,
verbose: bool,
) -> None:
try:
from pydeseq2.dds import DeseqDataSet
@@ -228,10 +247,16 @@ def _run_pydeseq2(
sys.exit(-1)

if isinstance(design_factors, str):
assert design_factors in pseudobulk.obs.columns, f"The design factor {design_factors} does not exist in data.obs!"
if design_factors not in pseudobulk.obs.columns:
import sys
logger.error(f"The design factor {design_factors} does not exist in data.obs!")
sys.exit(-1)
else:
for factor in design_factors:
assert factor in pseudobulk.obs.columns, f"The design factor {factor} does not exist in data.obs!"
if factor not in pseudobulk.obs.columns:
import sys
logger.error(f"The design factor {factor} does not exist in data.obs!")
sys.exit(-1)

counts_df = pd.DataFrame(pseudobulk.get_matrix(mat_key), index=pseudobulk.obs_names, columns=pseudobulk.var_names)
metadata = pseudobulk.obs
@@ -242,19 +267,17 @@
counts=counts_df,
metadata=metadata,
design_factors=design_factors,
refit_cooks=True,
inference=inference,
quiet=True,
quiet=not verbose,
)
dds.deseq2()

stat_res = DeseqStats(
dds,
contrast=contrast,
cooks_filter=True,
alpha=alpha,
inference=inference,
quiet=True,
quiet=not verbose,
)
stat_res.summary()
res_key = de_key if mat_key == "counts" else mat_key.removesuffix(".X") + "." + de_key
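
The res_key logic above means that, with compute_all=True, results for non-default matrices land under a prefixed varm key; a sketch of retrieving them afterwards (the "T_cell.X" matrix name is hypothetical):

# after pg.deseq2(pseudo, ..., compute_all=True)
pseudo.varm["deseq2"]         # result for the default "counts" matrix
pseudo.varm["T_cell.deseq2"]  # result for a "T_cell.X" matrix ("T_cell.X" -> "T_cell" + "." + de_key)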
Expand All @@ -263,7 +286,7 @@ def _run_pydeseq2(


def _run_rdeseq2(
pseudobulk: Union[MultimodalData, UnimodalData],
pseudobulk: MultimodalData,
mat_key: str,
design: str,
contrast: Tuple[str, str, str],
@@ -293,7 +316,10 @@
logger.error(text)
sys.exit(-1)

assert design.strip().startswith("~"), f"Design '{design}' is not a valid R formula! Valid examples: '~var', '~var1+var2', '~var1+var2+var1:var2'."
if not design.strip().startswith("~"):
import sys
logger.error(f"Design '{design}' is not a valid R formula! Valid examples: '~var', '~var1+var2', '~var1+var2+var1:var2'.")
sys.exit(-1)

import math
to_dataframe = ro.r('function(x) data.frame(x)')