diff --git a/.github/workflows/pr_validation.yml b/.github/workflows/pr_validation.yml index 4c14dfcb..4c55eebc 100644 --- a/.github/workflows/pr_validation.yml +++ b/.github/workflows/pr_validation.yml @@ -17,8 +17,6 @@ jobs: strategy: matrix: include: - - os: ubuntu-latest - python: "3.8" - os: ubuntu-latest python: "3.9" - os: ubuntu-latest diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index 665ef5b9..879101ba 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import sys import time import warnings @@ -21,6 +22,7 @@ from pydeseq2.preprocessing import deseq2_norm_transform from pydeseq2.utils import build_design_matrix from pydeseq2.utils import dispersion_trend +from pydeseq2.utils import local_trend_fit from pydeseq2.utils import make_scatter from pydeseq2.utils import mean_absolute_deviation from pydeseq2.utils import n_or_more_replicates @@ -71,6 +73,11 @@ class DeseqDataSet(ad.AnnData): specifying the factor of interest and the reference (control) level against which we're testing, e.g. ``["condition", "A"]``. (default: ``None``). + disp_function_type : str + Either "parametric", "local" or "mean", for the type of fitting of dispersions + trend curve. If "parametric" is selected but the fitting fails, it will switch to + "local". (default: ``"parametric"``). + min_mu : float Threshold for mean estimates. (default: ``0.5``). 
@@ -181,6 +188,7 @@ def __init__( design_factors: Union[str, List[str]] = "condition", continuous_factors: Optional[List[str]] = None, ref_level: Optional[List[str]] = None, + disp_function_type: Literal["parametric", "local", "mean"] = "parametric", min_mu: float = 0.5, min_disp: float = 1e-8, max_disp: float = 10.0, @@ -266,6 +274,7 @@ def __init__( # Check that the design matrix has full rank self._check_full_rank_design() + self.disp_function_type = disp_function_type self.min_mu = min_mu self.min_disp = min_disp self.max_disp = np.maximum(max_disp, self.n_obs) @@ -343,7 +352,7 @@ def vst_fit( if use_design: # Check that the dispersion trend curve was fitted. If not, fit it. # This will call previous functions in a cascade. - if "trend_coeffs" not in self.uns: + if "disp_function" not in self.uns: self.fit_dispersion_trend() else: # Reduce the design matrix to an intercept and reconstruct at the end @@ -477,7 +486,7 @@ def fit_size_factors( warnings.warn( "Every gene contains at least one zero, " "cannot compute log geometric means. Switching to iterative mode.", - RuntimeWarning, + UserWarning, stacklevel=2, ) self._fit_iterate_size_factors() @@ -567,72 +576,58 @@ def fit_genewise_dispersions(self) -> None: self.varm["_genewise_converged"][self.varm["non_zero"]] = l_bfgs_b_converged_ def fit_dispersion_trend(self) -> None: - r"""Fit the dispersion trend coefficients. + r"""Fit the dispersion trend curve. - :math:`f(\mu) = \alpha_1/\mu + a_0`. + Three methods are available, depending on the ``disp_function_type`` attribute: + "parametric", "local" and "mean". """ - # Check that genewise dispersions are available. If not, compute them. 
- if "genewise_dispersions" not in self.varm: - self.fit_genewise_dispersions() - - if not self.quiet: - print("Fitting dispersion trend curve...", file=sys.stderr) - start = time.time() - self.varm["_normed_means"] = self.layers["normed_counts"].mean(0) - - # Exclude all-zero counts - targets = pd.Series( - self[:, self.non_zero_genes].varm["genewise_dispersions"].copy(), - index=self.non_zero_genes, - ) - covariates = pd.Series( - 1 / self[:, self.non_zero_genes].varm["_normed_means"], - index=self.non_zero_genes, - ) + if self.disp_function_type == "parametric": + try: + self._fit_parametric_trend() + except RuntimeError: + # A failed parametric fit falls back to local regression, which may + # itself fall back to the "mean" trend further down. + warnings.warn( + "The dispersion trend curve fitting did not converge. " + "Switching to local regression.", + UserWarning, + stacklevel=2, + ) + self.disp_function_type = "local" - for gene in self.non_zero_genes: - if ( - np.isinf(covariates.loc[gene]).any() - or np.isnan(covariates.loc[gene]).any() - ): - targets.drop(labels=[gene], inplace=True) - covariates.drop(labels=[gene], inplace=True) + # A failed fit stores no coefficients, so only inspect them when present. + if "trend_coeffs" in self.uns and (self.uns["trend_coeffs"] == 0).any(): + warnings.warn( + f"self.disp_function_type={self.disp_function_type}, but the " + f"dispersion trend was not well captured by the function: " + f"y = a / x + b. 
Switching to local regression.", + UserWarning, + stacklevel=2, + ) + self.disp_function_type = "local" + del self.uns["trend_coeffs"] - # Initialize coefficients - old_coeffs = pd.Series([0.1, 0.1]) - coeffs = pd.Series([1.0, 1.0]) + if self.disp_function_type == "local": + try: + self._fit_local_trend() + except (ValueError, RuntimeError): + print("Local trend fit did not converge, switching to mean fit.") + self.disp_function_type = "mean" - while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6: - old_coeffs = coeffs - coeffs, predictions = self.inference.dispersion_trend_gamma_glm( - covariates, targets - ) - # Filter out genes that are too far away from the curve before refitting - pred_ratios = ( - self[:, covariates.index].varm["genewise_dispersions"] / predictions - ) + if self.disp_function_type == "mean": + self._fit_mean_trend() - targets.drop( - targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, - inplace=True, - ) - covariates.drop( - covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, - inplace=True, + if self.disp_function_type not in ["parametric", "local", "mean"]: + raise NotImplementedError( + f"Unknown disp_function_type: {self.disp_function_type}. " + "Expected 'parametric', 'local' or 'mean'." ) - end = time.time() - - if not self.quiet: - print(f"... 
done in {end - start:.2f} seconds.\n", file=sys.stderr) - - self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"]) - - self.varm["fitted_dispersions"] = np.full(self.n_vars, np.NaN) - self.varm["fitted_dispersions"][self.varm["non_zero"]] = dispersion_trend( - self.varm["_normed_means"][self.varm["non_zero"]], - self.uns["trend_coeffs"], - ) + def disp_function(self, x): + """Return the dispersion trend function at x.""" + if self.disp_function_type == "parametric": + return dispersion_trend(x, self.uns["trend_coeffs"]) + elif self.disp_function_type == "local": + return np.exp(self.uns["loess"].predict(np.log(x)).values) + elif self.disp_function_type == "mean": + return self.uns["mean_disp"] def fit_dispersion_prior(self) -> None: """Fit dispersion variance priors and standard deviation of log-residuals. @@ -859,6 +854,142 @@ def _fit_MoM_dispersions(self) -> None: alpha_hat, self.min_disp, self.max_disp ) + def _fit_parametric_trend(self) -> None: + r"""Fit the dispersion trend coefficients. + + :math:`f(\mu) = \alpha_1/\mu + a_0`. + """ + # Check that genewise dispersions are available. If not, compute them. 
+ if "genewise_dispersions" not in self.varm: + self.fit_genewise_dispersions() + + if not self.quiet: + print("Fitting dispersion trend curve...", file=sys.stderr) + start = time.time() + self.varm["_normed_means"] = self.layers["normed_counts"].mean(0) + + # Exclude all-zero counts + targets = pd.Series( + self[:, self.non_zero_genes].varm["genewise_dispersions"].copy(), + index=self.non_zero_genes, + ) + covariates = pd.Series( + 1 / self[:, self.non_zero_genes].varm["_normed_means"], + index=self.non_zero_genes, + ) + + for gene in self.non_zero_genes: + if ( + np.isinf(covariates.loc[gene]).any() + or np.isnan(covariates.loc[gene]).any() + ): + targets.drop(labels=[gene], inplace=True) + covariates.drop(labels=[gene], inplace=True) + + # Initialize coefficients + old_coeffs = pd.Series([0.1, 0.1]) + coeffs = pd.Series([1.0, 1.0]) + + try: + while (coeffs > 0).all() and ( + np.log(np.abs(coeffs / old_coeffs)) ** 2 + ).sum() >= 1e-6: + old_coeffs = coeffs + coeffs, predictions = self.inference.dispersion_trend_gamma_glm( + covariates, targets + ) + # Filter out genes that are too far away from the curve before refitting + pred_ratios = ( + self[:, covariates.index].varm["genewise_dispersions"] / predictions + ) + + targets.drop( + targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, + inplace=True, + ) + covariates.drop( + covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, + inplace=True, + ) + except RuntimeError as e: + raise e + + end = time.time() + + if not self.quiet: + print(f"... 
done in {end - start:.2f} seconds.\n", file=sys.stderr) + + self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"]) + + self.varm["fitted_dispersions"] = np.full(self.n_vars, np.nan) + self.varm["fitted_dispersions"][self.varm["non_zero"]] = self.disp_function( + self.varm["_normed_means"][self.varm["non_zero"]] + ) + + def _fit_local_trend(self) -> None: + r"""Fit the dispersion trend curve using local regression. + + Raises ``ValueError`` when no gene is eligible for the fit, and + ``RuntimeError`` when the trial fit in the child process exits abnormally. + """ + warnings.warn( + "Running local trend fit will make the DeseqDataSet object unpicklable.", + UserWarning, + stacklevel=2, + ) + + # Check that genewise dispersions are available. If not, compute them. + if "genewise_dispersions" not in self.varm: + self.fit_genewise_dispersions() + + if not self.quiet: + print("Fitting dispersion trend curve...", file=sys.stderr) + start = time.time() + self.varm["_normed_means"] = self.layers["normed_counts"].mean(0) + + genes_to_fit = self.varm["non_zero"] & ( + self.varm["genewise_dispersions"] >= 10 * self.min_disp + ) + + # genes_to_fit is a boolean mask: its length never shrinks, so test whether + # any entry is selected instead of comparing its length to zero. + if not genes_to_fit.any(): + raise ValueError("No genes to fit: all dispersions are below 10 * min_disp") + + means = self.varm["_normed_means"][genes_to_fit] + dispersions = self.varm["genewise_dispersions"][genes_to_fit] + + # Run local trend fit in a separate process to avoid segmentation faults + q = mp.Manager().Queue() + p = mp.Process(target=local_trend_fit, args=(means, dispersions, q)) + p.start() + p.join() + + if p.exitcode != 0: + raise RuntimeError( + f"Local trend fit subprocess exited with code {p.exitcode}." + ) + else: + # The code ran without fault, so we can now run it locally. + # Returning the loess object in the first step would be more straightforward, + # but it is unfortunately unpicklable. + + lo = local_trend_fit(means, dispersions) + + end = time.time() + + if not self.quiet: + print(f"... 
done in {end - start:.2f} seconds.\n", file=sys.stderr) + + self.uns["loess"] = lo + self.varm["fitted_dispersions"] = np.full(self.n_vars, np.nan) + self.varm["fitted_dispersions"][self.varm["non_zero"]] = self.disp_function( + self.varm["_normed_means"][self.varm["non_zero"]] + ) + + def _fit_mean_trend(self) -> None: + """Fit the dispersion trend curve using the mean of gene-wise dispersions. + + The trimmed mean of the gene-wise dispersions above ``10 * min_disp`` is used + as the fitted dispersion of every gene and is stored in ``uns["mean_disp"]``. + """ + # Check that genewise dispersions are available. If not, compute them. + if "genewise_dispersions" not in self.varm: + self.fit_genewise_dispersions() + + mean_disp = trim_mean( + self.varm["genewise_dispersions"][ + self.varm["genewise_dispersions"] > 10 * self.min_disp + ], + proportiontocut=0.001, + ) + + # ``disp_function`` reads this value back when the trend type is "mean". + self.uns["mean_disp"] = mean_disp + self.varm["fitted_dispersions"] = np.full(self.n_vars, mean_disp) + def plot_dispersions( self, log: bool = True, save_path: Optional[str] = None, **kwargs ) -> None: @@ -988,6 +1119,7 @@ def _refit_without_outliers( min_replicates=self.min_replicates, beta_tol=self.beta_tol, inference=self.inference, + disp_function_type=self.disp_function_type, ) # Use the same size factors @@ -1001,11 +1133,10 @@ def _refit_without_outliers( # Compute trend dispersions. # Note: the trend curve is not refitted. - sub_dds.uns["trend_coeffs"] = self.uns["trend_coeffs"] sub_dds.varm["_normed_means"] = sub_dds.layers["normed_counts"].mean(0) - sub_dds.varm["fitted_dispersions"] = dispersion_trend( - sub_dds.varm["_normed_means"], - sub_dds.uns["trend_coeffs"], + + sub_dds.varm["fitted_dispersions"] = self.disp_function( + sub_dds.varm["_normed_means"][sub_dds.varm["non_zero"]] ) # Estimate MAP dispersions. @@ -1089,6 +1220,13 @@ def objective(p): & self.varm["non_zero"] ] + if len(use_for_mean_genes) == 0: + print( + "No genes have a dispersion above 10 * min_disp in " + "_fit_iterate_size_factors." 
+ ) + break + mean_disp = trimmed_mean( self[:, use_for_mean_genes].varm["genewise_dispersions"], trim=0.001 ) diff --git a/pydeseq2/default_inference.py b/pydeseq2/default_inference.py index 29ebb65c..975375cd 100644 --- a/pydeseq2/default_inference.py +++ b/pydeseq2/default_inference.py @@ -1,22 +1,17 @@ -import warnings from typing import Literal from typing import Optional from typing import Tuple import numpy as np import pandas as pd -import statsmodels.api as sm # type: ignore from joblib import Parallel # type: ignore from joblib import delayed from joblib import parallel_backend -from statsmodels.tools.sm_exceptions import DomainWarning # type: ignore +from scipy.optimize import minimize # type: ignore from pydeseq2 import inference from pydeseq2 import utils -# Ignore DomainWarning raised by statsmodels when fitting a Gamma GLM with identity link. -warnings.simplefilter("ignore", DomainWarning) - class DefaultInference(inference.Inference): """Default DESeq2-related inference methods, using scipy/sklearn/numpy. 
@@ -207,17 +202,34 @@ def wald_test( # noqa: D102 def dispersion_trend_gamma_glm( # noqa: D102 self, covariates: pd.Series, targets: pd.Series ) -> Tuple[np.ndarray, np.ndarray]: - covariates_w_intercept = sm.add_constant(covariates) - targets_fit = targets.values + covariates_w_intercept = covariates.to_frame() + covariates_w_intercept.insert(0, "intercept", 1) covariates_fit = covariates_w_intercept.values - glm_gamma = sm.GLM( - targets_fit, - covariates_fit, - family=sm.families.Gamma(link=sm.families.links.identity()), + targets_fit = targets.values + + def loss(coeffs): + mu = covariates_fit @ coeffs + return (targets_fit / mu + np.log(mu)).mean() + + def grad(coeffs): + mu = covariates_fit @ coeffs + return -( + ((targets_fit / mu - 1)[:, None] * covariates_fit) / mu[:, None] + ).mean(0) + + res = minimize( + loss, + x0=np.array([1.0, 1.0]), + jac=grad, + method="L-BFGS-B", + # One (min, max) pair per coefficient: not every SciPy release + # broadcasts a single pair to both parameters. + bounds=[(0, np.inf), (0, np.inf)], ) - res = glm_gamma.fit() - coeffs = res.params - return (coeffs, covariates_fit @ coeffs) + + if not res.success: + raise RuntimeError("Gamma GLM optimization failed.") + + coeffs = res.x + return coeffs, covariates_fit @ coeffs def lfc_shrink_nbinom_glm( # noqa: D102 self, diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 0f04e0f5..3a46d682 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -1,6 +1,7 @@ import multiprocessing import warnings from math import floor +from multiprocessing import Queue from pathlib import Path from typing import List from typing import Literal @@ -18,6 +19,7 @@ from scipy.special import polygamma # type: ignore from scipy.stats import norm # type: ignore from sklearn.linear_model import LinearRegression # type: ignore +from skmisc.loess import loess import pydeseq2 from pydeseq2.grid_search import grid_fit_alpha @@ -724,6 +726,20 @@ def dloss(log_alpha: float) -> float: ) +def local_trend_fit(means, dispersions, q: Optional[Queue] = None): + """Run a wrapper for local trend fit to catch segfaults from scikit-misc.""" + 
lo = loess( + x=np.log(means), + y=np.log(dispersions), + weights=means, + surface="direct", # to allow extrapolation + ) + + lo.fit() + + return lo + + def trimmed_mean(x, trim: float = 0.1, **kwargs) -> Union[float, np.ndarray]: """Return trimmed mean. diff --git a/setup.py b/setup.py index 5bb3269c..1870950e 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name="pydeseq2", version=about["__version__"], - python_requires=">=3.8.0", + python_requires=">=3.9.0", license="MIT", description="A python implementation of DESeq2.", long_description=readme, @@ -31,6 +31,7 @@ "numpy>=1.23.0", "pandas>=1.4.0", "scikit-learn>=1.1.0", + "scikit-misc>=0.3.1", "scipy>=1.8.0", "statsmodels", "matplotlib>=3.6.2", # not sure why sphinx_gallery does not work without it diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 5a2b8a71..b93a853e 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -468,7 +468,7 @@ def test_zero_inflated(): counts_df.iloc[idx, :] = 0 dds = DeseqDataSet(counts=counts_df, metadata=metadata) - with pytest.warns(RuntimeWarning): + with pytest.warns(UserWarning): dds.deseq2()