Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH implement local fit for trend curve #234

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/pr_validation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ jobs:
strategy:
matrix:
include:
- os: ubuntu-latest
python: "3.8"
- os: ubuntu-latest
python: "3.9"
- os: ubuntu-latest
Expand Down
266 changes: 202 additions & 64 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import multiprocessing as mp
import sys
import time
import warnings
Expand All @@ -21,6 +22,7 @@
from pydeseq2.preprocessing import deseq2_norm_transform
from pydeseq2.utils import build_design_matrix
from pydeseq2.utils import dispersion_trend
from pydeseq2.utils import local_trend_fit
from pydeseq2.utils import make_scatter
from pydeseq2.utils import mean_absolute_deviation
from pydeseq2.utils import n_or_more_replicates
Expand Down Expand Up @@ -71,6 +73,11 @@ class DeseqDataSet(ad.AnnData):
specifying the factor of interest and the reference (control) level against which
we're testing, e.g. ``["condition", "A"]``. (default: ``None``).

disp_function_type : str
Either "parametric", "local" or "mean", for the type of fitting of dispersions
trend curve.If "parametric" is selected but the fitting fails, it will switch to
"local". (default: ``"parametric"``).

min_mu : float
Threshold for mean estimates. (default: ``0.5``).

Expand Down Expand Up @@ -181,6 +188,7 @@ def __init__(
design_factors: Union[str, List[str]] = "condition",
continuous_factors: Optional[List[str]] = None,
ref_level: Optional[List[str]] = None,
disp_function_type: Literal["parametric", "local", "mean"] = "parametric",
min_mu: float = 0.5,
min_disp: float = 1e-8,
max_disp: float = 10.0,
Expand Down Expand Up @@ -266,6 +274,7 @@ def __init__(
# Check that the design matrix has full rank
self._check_full_rank_design()

self.disp_function_type = disp_function_type
self.min_mu = min_mu
self.min_disp = min_disp
self.max_disp = np.maximum(max_disp, self.n_obs)
Expand Down Expand Up @@ -343,7 +352,7 @@ def vst_fit(
if use_design:
# Check that the dispersion trend curve was fitted. If not, fit it.
# This will call previous functions in a cascade.
if "trend_coeffs" not in self.uns:
if "disp_function" not in self.uns:
self.fit_dispersion_trend()
else:
# Reduce the design matrix to an intercept and reconstruct at the end
Expand Down Expand Up @@ -477,7 +486,7 @@ def fit_size_factors(
warnings.warn(
"Every gene contains at least one zero, "
"cannot compute log geometric means. Switching to iterative mode.",
RuntimeWarning,
UserWarning,
stacklevel=2,
)
self._fit_iterate_size_factors()
Expand Down Expand Up @@ -567,72 +576,58 @@ def fit_genewise_dispersions(self) -> None:
self.varm["_genewise_converged"][self.varm["non_zero"]] = l_bfgs_b_converged_

def fit_dispersion_trend(self) -> None:
r"""Fit the dispersion trend coefficients.
r"""Fit the dispersion trend curve.

:math:`f(\mu) = \alpha_1/\mu + a_0`.
Three methods are available, depending on the ``disp_function_type`` attribute:
"parametric", "local" and "mean".
"""
# Check that genewise dispersions are available. If not, compute them.
if "genewise_dispersions" not in self.varm:
self.fit_genewise_dispersions()

if not self.quiet:
print("Fitting dispersion trend curve...", file=sys.stderr)
start = time.time()
self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)

# Exclude all-zero counts
targets = pd.Series(
self[:, self.non_zero_genes].varm["genewise_dispersions"].copy(),
index=self.non_zero_genes,
)
covariates = pd.Series(
1 / self[:, self.non_zero_genes].varm["_normed_means"],
index=self.non_zero_genes,
)
if self.disp_function_type == "parametric":
try:
self._fit_parametric_trend()
except RuntimeError:
warnings.warn(
"The dispersion trend curve fitting did not converge. "
"Switching to a mean-based dispersion trend.",
UserWarning,
stacklevel=2,
)
self.disp_function_type = "local"

for gene in self.non_zero_genes:
if (
np.isinf(covariates.loc[gene]).any()
or np.isnan(covariates.loc[gene]).any()
):
targets.drop(labels=[gene], inplace=True)
covariates.drop(labels=[gene], inplace=True)
if (self.uns["trend_coeffs"] == 0).any():
warnings.warn(
f"self.disp_function_type={self.disp_function_type}, but the "
f"dispersion trend was not well captured by the function: "
f"y = a / x + b. Switching to local regression.",
UserWarning,
stacklevel=2,
)
self.disp_function_type = "local"
del self.uns["trend_coeffs"]

# Initialize coefficients
old_coeffs = pd.Series([0.1, 0.1])
coeffs = pd.Series([1.0, 1.0])
if self.disp_function_type == "local":
try:
self._fit_local_trend()
except (ValueError, RuntimeError):
print("Local trend fit did not converge, switching to mean fit.")
self.disp_function_type = "mean"

while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6:
old_coeffs = coeffs
coeffs, predictions = self.inference.dispersion_trend_gamma_glm(
covariates, targets
)
# Filter out genes that are too far away from the curve before refitting
pred_ratios = (
self[:, covariates.index].varm["genewise_dispersions"] / predictions
)
if self.disp_function_type == "mean":
self._fit_mean_trend()

targets.drop(
targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
covariates.drop(
covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
if self.disp_function_type not in ["parametric", "local", "mean"]:
raise NotImplementedError(
f"Unknown disp_function_type: {self.disp_function_type}. "
"Expected 'parametric', 'local' or 'mean'."
)

end = time.time()

if not self.quiet:
print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"])

self.varm["fitted_dispersions"] = np.full(self.n_vars, np.NaN)
self.varm["fitted_dispersions"][self.varm["non_zero"]] = dispersion_trend(
self.varm["_normed_means"][self.varm["non_zero"]],
self.uns["trend_coeffs"],
)
def disp_function(self, x):
"""Return the dispersion trend function at x."""
if self.disp_function_type == "parametric":
return dispersion_trend(x, self.uns["trend_coeffs"])
elif self.disp_function_type == "local":
return np.exp(self.uns["loess"].predict(np.log(x)).values)
elif self.disp_function_type == "mean":
return self.uns["mean_disp"]

def fit_dispersion_prior(self) -> None:
"""Fit dispersion variance priors and standard deviation of log-residuals.
Expand Down Expand Up @@ -859,6 +854,142 @@ def _fit_MoM_dispersions(self) -> None:
alpha_hat, self.min_disp, self.max_disp
)

def _fit_parametric_trend(self) -> None:
r"""Fit the dispersion trend coefficients.

:math:`f(\mu) = \alpha_1/\mu + a_0`.
"""
# Check that genewise dispersions are available. If not, compute them.
if "genewise_dispersions" not in self.varm:
self.fit_genewise_dispersions()

if not self.quiet:
print("Fitting dispersion trend curve...", file=sys.stderr)
start = time.time()
self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)

# Exclude all-zero counts
targets = pd.Series(
self[:, self.non_zero_genes].varm["genewise_dispersions"].copy(),
index=self.non_zero_genes,
)
covariates = pd.Series(
1 / self[:, self.non_zero_genes].varm["_normed_means"],
index=self.non_zero_genes,
)

for gene in self.non_zero_genes:
if (
np.isinf(covariates.loc[gene]).any()
or np.isnan(covariates.loc[gene]).any()
):
targets.drop(labels=[gene], inplace=True)
covariates.drop(labels=[gene], inplace=True)

# Initialize coefficients
old_coeffs = pd.Series([0.1, 0.1])
coeffs = pd.Series([1.0, 1.0])

try:
while (coeffs > 0).all() and (
np.log(np.abs(coeffs / old_coeffs)) ** 2
).sum() >= 1e-6:
old_coeffs = coeffs
coeffs, predictions = self.inference.dispersion_trend_gamma_glm(
covariates, targets
)
# Filter out genes that are too far away from the curve before refitting
pred_ratios = (
self[:, covariates.index].varm["genewise_dispersions"] / predictions
)

targets.drop(
targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
covariates.drop(
covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
except RuntimeError as e:
raise e

end = time.time()

if not self.quiet:
print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"])

self.varm["fitted_dispersions"] = np.full(self.n_vars, np.NaN)
self.varm["fitted_dispersions"][self.varm["non_zero"]] = self.disp_function(
self.varm["_normed_means"][self.varm["non_zero"]]
)

def _fit_local_trend(self) -> None:
r"""Fit the dispersion trend curve using local regression."""
# Check that genewise dispersions are available. If not, compute them.
warnings.warn(
"Running local trend fit will make the DeseqDataSet object unpicklable.",
UserWarning,
stacklevel=2,
)

if "genewise_dispersions" not in self.varm:
self.fit_genewise_dispersions()

if not self.quiet:
print("Fitting dispersion trend curve...", file=sys.stderr)
start = time.time()
self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)

genes_to_fit = self.varm["non_zero"] & (
self.varm["genewise_dispersions"] >= 10 * self.min_disp
)

if len(genes_to_fit) == 0:
raise ValueError("No genes to fit: all dispersions are below 10 * min_disp")

means = self.varm["_normed_means"][genes_to_fit]
dispersions = self.varm["genewise_dispersions"][genes_to_fit]

# Run local trend fit in a separate process to avoid segmentation faults
q = mp.Manager().Queue()
p = mp.Process(target=local_trend_fit, args=(means, dispersions, q))
p.start()
p.join()

if p.exitcode != 0:
raise RuntimeError
else:
# The code rand without fault, so we can now run it locally.
# Returning the loess object in the first step would be more straightforward,
# But it is unfortunately unpicklable.

lo = local_trend_fit(means, dispersions)

end = time.time()

if not self.quiet:
print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

self.uns["loess"] = lo
self.varm["fitted_dispersions"] = np.full(self.n_vars, np.NaN)
self.varm["fitted_dispersions"][self.varm["non_zero"]] = self.disp_function(
self.varm["_normed_means"][self.varm["non_zero"]]
)

def _fit_mean_trend(self):
"""Fit the dispersion trend curve using the mean of gene-wise dispersions."""
mean_disp = trim_mean(
self.varm["genewise_dispersions"][
self.varm["genewise_dispersions"] > 10 * self.min_disp
],
proportiontocut=0.001,
)

self.varm["fitted_dispersions"] = np.full(self.n_vars, mean_disp)

def plot_dispersions(
self, log: bool = True, save_path: Optional[str] = None, **kwargs
) -> None:
Expand Down Expand Up @@ -988,6 +1119,7 @@ def _refit_without_outliers(
min_replicates=self.min_replicates,
beta_tol=self.beta_tol,
inference=self.inference,
disp_function_type=self.disp_function_type,
)

# Use the same size factors
Expand All @@ -1001,11 +1133,10 @@ def _refit_without_outliers(

# Compute trend dispersions.
# Note: the trend curve is not refitted.
sub_dds.uns["trend_coeffs"] = self.uns["trend_coeffs"]
sub_dds.varm["_normed_means"] = sub_dds.layers["normed_counts"].mean(0)
sub_dds.varm["fitted_dispersions"] = dispersion_trend(
sub_dds.varm["_normed_means"],
sub_dds.uns["trend_coeffs"],

sub_dds.varm["fitted_dispersions"] = self.disp_function(
sub_dds.varm["_normed_means"][sub_dds.varm["non_zero"]]
)

# Estimate MAP dispersions.
Expand Down Expand Up @@ -1089,6 +1220,13 @@ def objective(p):
& self.varm["non_zero"]
]

if len(use_for_mean_genes) == 0:
print(
"No genes have a dispersion above 10 * min_disp in "
"_fit_iterate_size_factors."
)
break

mean_disp = trimmed_mean(
self[:, use_for_mean_genes].varm["genewise_dispersions"], trim=0.001
)
Expand Down
Loading
Loading