Skip to content

Commit

Permalink
Merge pull request #303 from lilab-bcb/fix-deseq2
Browse files Browse the repository at this point in the history
Fix deseq2
  • Loading branch information
yihming authored Jun 5, 2024
2 parents eaeeee6 + bda3a26 commit 4122e5d
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pegasus/pseudo/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .convenient import markers, write_results_to_excel, volcano
from .convenient import markers, write_results_to_excel, volcano, get_original_DE_result
52 changes: 38 additions & 14 deletions pegasus/pseudo/convenient.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import logging
logger = logging.getLogger(__name__)

from pegasusio import UnimodalData, timer
from pegasusio import MultimodalData, timer
from typing import Union, Dict, Optional, Tuple



def markers(
pseudobulk: UnimodalData,
pseudobulk: MultimodalData,
head: int = None,
de_key: str = "deseq2",
alpha: float = 0.05,
Expand All @@ -23,11 +23,11 @@ def markers(
Parameters
----------
pseudobulk: ``UnimodalData``
pseudobulk: ``MultimodalData``
Pseudobulk data matrix with rows for cells and columns for genes.
head: ``int``, optional, default: ``None``
List only top ``head`` genes. If ``None``, show any DE genes.
List only top ``head`` genes. If ``None``, show all DE genes.
de_key: ``str``, optional, default, ``deseq2``
Keyword of DE result stored in ``data.varm``.
Expand All @@ -39,8 +39,8 @@ def markers(
-------
results: ``Dict[str, pd.DataFrame]``
A Python dictionary containing DE results. This dictionary contains two keywords: 'up' and 'down'.
'up' refers to up-regulated genes, which should have 'log2FoldChange' > 0.5.
'down' refers to down-regulated genes, which should have 'log2FoldChange' < 0.5.
'up' refers to up-regulated genes, which should have 'log2FoldChange' > 0.5. The genes are ranked by Wald test statistics.
'down' refers to down-regulated genes, which should have 'log2FoldChange' < 0.5. The genes are ranked by Wald test statistics.
Examples
--------
Expand All @@ -53,11 +53,11 @@ def markers(
df = pd.DataFrame(data=pseudobulk.varm[de_key], index=pseudobulk.var_names)
idx = df["padj"] <= alpha

idx_up = idx & (df["log2FoldChange"].values > 0.0)
df_up = df.loc[idx_up].sort_values(by="log2FoldChange", ascending=False, inplace=False)
idx_up = idx & (df["stat"].values > 0.0)
df_up = df.loc[idx_up].sort_values(by="stat", ascending=False, inplace=False)
res_dict["up"] = pd.DataFrame(df_up if head is None else df_up.iloc[0:head])
idx_down = idx & (df["log2FoldChange"].values < 0.0)
df_down = df.loc[idx_down].sort_values(by="log2FoldChange", ascending=True, inplace=False)
idx_down = idx & (df["stat"].values < 0.0)
df_down = df.loc[idx_down].sort_values(by="stat", ascending=True, inplace=False)
res_dict["down"] = pd.DataFrame(df_down if head is None else df_down.iloc[0:head])

return res_dict
Expand Down Expand Up @@ -149,10 +149,11 @@ def add_worksheet(


def volcano(
pseudobulk: UnimodalData,
pseudobulk: MultimodalData,
de_key: str = "deseq2",
qval_threshold: float = 0.05,
log2fc_threshold: float = 1.0,
rank_by: str = "log2fc",
top_n: int = 20,
panel_size: Optional[Tuple[float, float]] = (6, 4),
return_fig: Optional[bool] = False,
Expand All @@ -164,14 +165,16 @@ def volcano(
Parameters
-----------
pseudobulk: ``UnimodalData`` object.
pseudobulk: ``MultimodalData`` object.
Pseudobulk data matrix.
de_key: ``str``, optional, default: ``deseq2``
The varm keyword for DE results. data.varm[de_key] should store the full DE result table.
qval_threshold: ``float``, optional, default: 0.05.
Selected FDR rate. A horizontal line indicating this rate will be shown in the figure.
log2fc_threshold: ``float``, optional, default: 1.0
Log2 fold change threshold to highlight biologically interesting genes. Two vertical lines representing negative and positive log2 fold change will be shown.
rank_by: ``str``, optional, default: ``log2fc``
Rank genes by ``rank_by`` metric in the DE results. By default, rank by log2 Fold Change (``"log2fc"``); change to ``"neglog10p"`` if you want to rank genes by p-values.
top_n: ``int``, optional, default: ``20``
Number of top DE genes to show names. Genes are ranked by Log2 fold change.
panel_size: ``Tuple[float, float]``, optional, default: ``(6, 4)``
Expand Down Expand Up @@ -258,13 +261,27 @@ def volcano(
texts = []

idx = np.where(idxsig & (log2fc >= log2fc_threshold))[0]
posvec = np.argsort(log2fc[idx])[::-1][0:top_n]
if rank_by == "log2fc":
posvec = np.argsort(log2fc[idx])[::-1][0:top_n]
elif rank_by == "neglog10p":
posvec = np.argsort(neglog10p[idx])[::-1][0:top_n]
else:
import sys
logger.error(f"Invalid rank_by key! Must choose from ['log2fc', 'neglog10p']!")
sys.exit(-1)
for pos in posvec:
gid = idx[pos]
texts.append(ax.text(log2fc[gid], neglog10p[gid], gene_names[gid], fontsize=5))

idx = np.where(idxsig & (log2fc <= -log2fc_threshold))[0]
posvec = np.argsort(log2fc[idx])[0:top_n]
if rank_by == "log2fc":
posvec = np.argsort(log2fc[idx])[0:top_n]
elif rank_by == "neglog10p":
posvec = np.argsort(neglog10p[idx])[::-1][0:top_n]
else:
import sys
logger.error(f"Invalid rank_by key! Must choose from ['log2fc', 'neglog10p']!")
sys.exit(-1)
for pos in posvec:
gid = idx[pos]
texts.append(ax.text(log2fc[gid], neglog10p[gid], gene_names[gid], fontsize=5))
Expand All @@ -273,3 +290,10 @@ def volcano(
adjust_text(texts, arrowprops=dict(arrowstyle='-', color='k', lw=0.5))

return fig if return_fig else None


def get_original_DE_result(
data: MultimodalData,
de_key: str = "deseq2",
) -> pd.DataFrame:
return pd.DataFrame(data.varm[de_key], index=data.var_names)
70 changes: 48 additions & 22 deletions pegasus/tools/pseudobulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def pseudobulk(
data: Union[MultimodalData, UnimodalData],
groupby: str,
attrs: Optional[Union[List[str], str]] = None,
mat_key: str = "counts",
mat_key: Optional[str] = None,
condition: Optional[str] = None,
) -> MultimodalData:
"""Generate Pseudo-bulk count matrices.
Expand All @@ -53,9 +53,9 @@ def pseudobulk(
Notice that for a categorical attribute, each pseudo-bulk's value is the one of highest frequency among its cells,
and for a numeric attribute, each pseudo-bulk's value is the mean among its cells.
mat_key: ``str``, optional, default: ``counts``
mat_key: ``str``, optional, default: ``None``
Specify the single-cell count matrix used for aggregating pseudo-bulk counts:
If specified, use the count matrix with key ``mat_key`` from matrices of ``data``; otherwise, default is ``counts``.
If specified, use the count matrix with key ``mat_key`` from matrices of ``data``; otherwise, first look for key ``counts``, then for ``raw.X`` if not existing.
condition: ``str``, optional, default: ``None``
If set, additionally generate pseudo-bulk matrices per condition specified in ``data.obs[condition]``.
Expand All @@ -74,9 +74,21 @@ def pseudobulk(
--------
>>> pg.pseudobulk(data, groupby="Channel")
"""
if mat_key is None:
if "counts" in data._unidata.matrices:
mat_key = "counts"
elif "raw.X" in data._unidata.matrices:
mat_key = "raw.X"
else:
import sys
logger.error("No matrix with default key found in data! Please specify an explicit matrix key!")
sys.exit(-1)
X = data.get_matrix(mat_key)

assert groupby in data.obs.columns, f"Sample key '{groupby}' must exist in data.obs!"
if groupby not in data.obs.columns:
import sys
logger.error(f"Sample key '{groupby}' must exist in data.obs!")
sys.exit(-1)

sample_vec = (
data.obs[groupby]
Expand All @@ -96,9 +108,10 @@ def pseudobulk(
if isinstance(attrs, str):
attrs = [attrs]
for attr in attrs:
assert (
attr in data.obs.columns
), f"Cell attribute key '{attr}' must exist in data.obs!"
if attr not in data.obs.columns:
import sys
logger.error(f"Cell attribute key '{attr}' must exist in data.obs!")
sys.exit(-1)

for bulk in bulk_list:
df_bulk = df_barcode.loc[df_barcode[groupby] == bulk]
Expand All @@ -124,9 +137,10 @@ def pseudobulk(
df_feature["featureid"] = data.var["featureid"]

if condition is not None:
assert (
condition in data.obs.columns
), f"Condition key '{attr}' must exist in data.obs!"
if condition not in data.obs.columns:
import sys
logger.error(f"Condition key '{attr}' must exist in data.obs!")
sys.exit(-1)

cluster_list = data.obs[condition].astype("category").cat.categories
for cls in cluster_list:
Expand All @@ -138,7 +152,7 @@ def pseudobulk(
barcode_metadata=df_pseudobulk,
feature_metadata=df_feature,
matrices=mat_dict,
genome=groupby,
genome=data.get_genome(),
modality="pseudobulk",
cur_matrix="counts",
)
Expand All @@ -148,13 +162,14 @@ def pseudobulk(

@timer(logger=logger)
def deseq2(
pseudobulk: Union[MultimodalData, UnimodalData],
pseudobulk: MultimodalData,
design: str,
contrast: Tuple[str, str, str],
backend: str = "pydeseq2",
de_key: str = "deseq2",
alpha: float = 0.05,
compute_all: bool = False,
verbose: bool = True,
n_jobs: int = -1,
) -> None:
"""Perform Differential Expression (DE) Analysis using DESeq2 on pseduobulk data.
Expand Down Expand Up @@ -185,6 +200,9 @@ def deseq2(
compute_all: ``bool``, optional, default: ``False``
If performing DE analysis on all count matrices. By default (``compute_all=False``), only apply DE analysis to the default count matrix ``counts``.
verbose: ``bool``, optional, default: ``True``
If showing DESeq2 status updates during fit. Only works when ``backend="pydeseq2"``.
n_jobs: ``int``, optional, default: ``-1``
Number of threads to use. If ``-1``, use all physical CPU cores. This only works when ``backend="pydeseq2"`.
Expand All @@ -204,19 +222,20 @@ def deseq2(
mat_keys = ['counts'] if not compute_all else pseudobulk.list_keys()
for mat_key in mat_keys:
if backend == "pydeseq2":
_run_pydeseq2(pseudobulk=pseudobulk, mat_key=mat_key, design_factors=design, contrast=contrast, de_key=de_key, alpha=alpha, n_jobs=n_jobs)
_run_pydeseq2(pseudobulk=pseudobulk, mat_key=mat_key, design_factors=design, contrast=contrast, de_key=de_key, alpha=alpha, n_jobs=n_jobs, verbose=verbose)
else:
_run_rdeseq2(pseudobulk=pseudobulk, mat_key=mat_key, design=design, contrast=contrast, de_key=de_key, alpha=alpha)


def _run_pydeseq2(
pseudobulk: Union[MultimodalData, UnimodalData],
pseudobulk: MultimodalData,
mat_key: str,
design_factors: Union[str, List[str]],
contrast: Tuple[str, str, str],
de_key: str,
alpha: float,
n_jobs: int,
verbose: bool,
) -> None:
try:
from pydeseq2.dds import DeseqDataSet
Expand All @@ -228,10 +247,16 @@ def _run_pydeseq2(
sys.exit(-1)

if isinstance(design_factors, str):
assert design_factors in pseudobulk.obs.columns, f"The design factor {design_factors} does not exist in data.obs!"
if design_factors not in pseudobulk.obs.columns:
import sys
logger.error(f"The design factor {design_factors} does not exist in data.obs!")
sys.exit(-1)
else:
for factor in design_factors:
assert factor in pseudobulk.obs.columns, f"The design factor {factor} does not exist in data.obs!"
if factor not in pseudobulk.obs.columns:
import sys
logger.error(f"The design factor {factor} does not exist in data.obs!")
sys.exit(-1)

counts_df = pd.DataFrame(pseudobulk.get_matrix(mat_key), index=pseudobulk.obs_names, columns=pseudobulk.var_names)
metadata = pseudobulk.obs
Expand All @@ -242,19 +267,17 @@ def _run_pydeseq2(
counts=counts_df,
metadata=metadata,
design_factors=design_factors,
refit_cooks=True,
inference=inference,
quiet=True,
quiet=not verbose,
)
dds.deseq2()

stat_res = DeseqStats(
dds,
contrast=contrast,
cooks_filter=True,
alpha=alpha,
inference=inference,
quiet=True,
quiet=not verbose,
)
stat_res.summary()
res_key = de_key if mat_key == "counts" else mat_key.removesuffix(".X") + "." + de_key
Expand All @@ -263,7 +286,7 @@ def _run_pydeseq2(


def _run_rdeseq2(
pseudobulk: Union[MultimodalData, UnimodalData],
pseudobulk: MultimodalData,
mat_key: str,
design: str,
contrast: Tuple[str, str, str],
Expand Down Expand Up @@ -293,7 +316,10 @@ def _run_rdeseq2(
logger.error(text)
sys.exit(-1)

assert design.strip().startswith("~"), f"Design '{design}' is not a valid R formula! Valid examples: '~var', '~var1+var2', '~var1+var2+var1:var2'."
if not design.strip().startswith("~"):
import sys
logger.error(f"Design '{design}' is not a valid R formula! Valid examples: '~var', '~var1+var2', '~var1+var2+var1:var2'.")
sys.exit(-1)

import math
to_dataframe = ro.r('function(x) data.frame(x)')
Expand Down

0 comments on commit 4122e5d

Please sign in to comment.