Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH implement local fit for trend curve #234

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: implementing local fit for trend curve (WIP)
  • Loading branch information
BorisMuzellec committed Feb 2, 2024
commit 1055aea07687c1e617a5587834abe39598196aba
229 changes: 159 additions & 70 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from scipy.stats import f # type: ignore
from scipy.stats import trim_mean # type: ignore

from skmisc.loess import loess

from pydeseq2.default_inference import DefaultInference
from pydeseq2.inference import Inference
from pydeseq2.preprocessing import deseq2_norm_fit
Expand Down Expand Up @@ -71,6 +73,11 @@ class DeseqDataSet(ad.AnnData):
specifying the factor of interest and the reference (control) level against which
we're testing, e.g. ``["condition", "A"]``. (default: ``None``).

trend_fit_type : str
Either "parametric" or "local" for the type of fitting of dispersions trend
curve. If "parametric" is selected but the fitting fails, it will switch to
"local". (default: ``"parametric"``).

min_mu : float
Threshold for mean estimates. (default: ``0.5``).

Expand Down Expand Up @@ -181,6 +188,7 @@ def __init__(
design_factors: Union[str, List[str]] = "condition",
continuous_factors: Optional[List[str]] = None,
ref_level: Optional[List[str]] = None,
trend_fit_type: Literal["parametric", "local"] = "parametric",
min_mu: float = 0.5,
min_disp: float = 1e-8,
max_disp: float = 10.0,
Expand Down Expand Up @@ -266,6 +274,7 @@ def __init__(
# Check that the design matrix has full rank
self._check_full_rank_design()

self.trend_fit_type = trend_fit_type
self.min_mu = min_mu
self.min_disp = min_disp
self.max_disp = np.maximum(max_disp, self.n_obs)
Expand Down Expand Up @@ -568,71 +577,25 @@ def fit_genewise_dispersions(self) -> None:

def fit_dispersion_trend(self) -> None:
r"""Fit the dispersion trend coefficients.

:math:`f(\mu) = \alpha_1/\mu + a_0`.
TODO
"""
# Check that genewise dispersions are available. If not, compute them.
if "genewise_dispersions" not in self.varm:
self.fit_genewise_dispersions()

if not self.quiet:
print("Fitting dispersion trend curve...", file=sys.stderr)
start = time.time()
self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)

# Exclude all-zero counts
targets = pd.Series(
self[:, self.non_zero_genes].varm["genewise_dispersions"].copy(),
index=self.non_zero_genes,
)
covariates = pd.Series(
1 / self[:, self.non_zero_genes].varm["_normed_means"],
index=self.non_zero_genes,
)

for gene in self.non_zero_genes:
if (
np.isinf(covariates.loc[gene]).any()
or np.isnan(covariates.loc[gene]).any()
):
targets.drop(labels=[gene], inplace=True)
covariates.drop(labels=[gene], inplace=True)

# Initialize coefficients
old_coeffs = pd.Series([0.1, 0.1])
coeffs = pd.Series([1.0, 1.0])

while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6:
old_coeffs = coeffs
coeffs, predictions = self.inference.dispersion_trend_gamma_glm(
covariates, targets
)
# Filter out genes that are too far away from the curve before refitting
pred_ratios = (
self[:, covariates.index].varm["genewise_dispersions"] / predictions
)

targets.drop(
targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
covariates.drop(
covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
if self.trend_fit_type == "parametric":
self._fit_parametric_trend()

end = time.time()

if not self.quiet:
print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"])
if (self.uns["trend_coeffs"] <= 0).any():
warnings.warn(
f"self.trend_fit_type={self.trend_fit_type}, but the dispersion trend was"
f" not well captured by the function: y = a / x + b. Switching to local "
f"regression.",
UserWarning,
stacklevel=2,
)
self.trend_fit_type = "local"
del self.uns["trend_coeffs"]

self.varm["fitted_dispersions"] = np.full(self.n_vars, np.NaN)
self.varm["fitted_dispersions"][self.varm["non_zero"]] = dispersion_trend(
self.varm["_normed_means"][self.varm["non_zero"]],
self.uns["trend_coeffs"],
)
if self.trend_fit_type == "local":
self._fit_local_trend()

def fit_dispersion_prior(self) -> None:
"""Fit dispersion variance priors and standard deviation of log-residuals.
Expand Down Expand Up @@ -721,9 +684,11 @@ def fit_MAP_dispersions(self) -> None:

# Filter outlier genes for which we won't apply shrinkage
self.varm["dispersions"] = self.varm["MAP_dispersions"].copy()
self.varm["_outlier_genes"] = np.log(self.varm["genewise_dispersions"]) > np.log(
self.varm["fitted_dispersions"]
) + 2 * np.sqrt(self.uns["_squared_logres"])
self.varm["_outlier_genes"] = np.log(
self.varm["genewise_dispersions"]
) > np.log(self.varm["fitted_dispersions"]) + 2 * np.sqrt(
self.uns["_squared_logres"]
)
self.varm["dispersions"][self.varm["_outlier_genes"]] = self.varm[
"genewise_dispersions"
][self.varm["_outlier_genes"]]
Expand Down Expand Up @@ -859,6 +824,115 @@ def _fit_MoM_dispersions(self) -> None:
alpha_hat, self.min_disp, self.max_disp
)

def _fit_parametric_trend(self) -> None:
    r"""Fit the parametric dispersion trend curve.

    The trend is modelled as :math:`f(\mu) = a_1/\mu + a_0`, where
    :math:`\mu` is the mean of the normalized counts of a gene. The
    coefficients are obtained by repeatedly calling
    ``self.inference.dispersion_trend_gamma_glm`` on the inverse normalized
    means (covariates) and the genewise dispersions (targets). After each
    fit, genes whose dispersion is too far from the fitted curve are removed
    before the next fit. Iteration stops once the summed squared log-ratio
    between successive coefficient estimates drops below ``1e-6``.

    Results are stored in place:

    * ``varm["_normed_means"]``: per-gene mean of ``layers["normed_counts"]``.
    * ``uns["trend_coeffs"]``: the coefficients, indexed by ``"a0"``, ``"a1"``.
    * ``varm["fitted_dispersions"]``: trend evaluated on the genes flagged in
      ``varm["non_zero"]``, ``NaN`` elsewhere.
    """
    # Check that genewise dispersions are available. If not, compute them.
    if "genewise_dispersions" not in self.varm:
        self.fit_genewise_dispersions()

    if not self.quiet:
        print("Fitting dispersion trend curve...", file=sys.stderr)
    start = time.time()
    self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)

    # Exclude all-zero counts
    targets = pd.Series(
        self[:, self.non_zero_genes].varm["genewise_dispersions"].copy(),
        index=self.non_zero_genes,
    )
    covariates = pd.Series(
        1 / self[:, self.non_zero_genes].varm["_normed_means"],
        index=self.non_zero_genes,
    )

    # Discard genes whose covariate is inf or NaN in a single vectorized
    # pass, instead of dropping labels one gene at a time.
    finite = np.isfinite(covariates.to_numpy())
    targets = targets[finite]
    covariates = covariates[finite]

    # Initialize coefficients so that the first convergence check fails
    # and the loop body runs at least once.
    old_coeffs = pd.Series([0.1, 0.1])
    coeffs = pd.Series([1.0, 1.0])

    while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6:
        old_coeffs = coeffs
        coeffs, predictions = self.inference.dispersion_trend_gamma_glm(
            covariates, targets
        )
        # Filter out genes that are too far away from the curve before refitting
        pred_ratios = (
            self[:, covariates.index].varm["genewise_dispersions"] / predictions
        )
        too_far = (pred_ratios < 1e-4) | (pred_ratios >= 15)
        # targets and covariates share the same index, so the labels to
        # discard only need to be computed once.
        discarded = targets[too_far].index
        targets.drop(discarded, inplace=True)
        covariates.drop(discarded, inplace=True)

    end = time.time()

    if not self.quiet:
        print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

    self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"])

    # np.nan rather than np.NaN: the capitalized alias was removed in NumPy 2.0.
    self.varm["fitted_dispersions"] = np.full(self.n_vars, np.nan)
    self.varm["fitted_dispersions"][self.varm["non_zero"]] = dispersion_trend(
        self.varm["_normed_means"][self.varm["non_zero"]],
        self.uns["trend_coeffs"],
    )

def _fit_local_trend(self) -> None:
    r"""Fit the dispersion trend curve by local regression (loess).

    Log genewise dispersions are regressed on log normalized means with a
    ``loess`` model, each gene being weighted by its normalized mean. Only
    genes flagged in ``varm["non_zero"]`` whose genewise dispersion is at
    least ``10 * min_disp`` take part in the fit.

    Results are stored in place:

    * ``varm["_normed_means"]``: per-gene mean of ``layers["normed_counts"]``.
    * ``uns["disp_function"]``: callable mapping normalized means to fitted
      dispersions (exponential of the loess prediction on the log scale).
    * ``varm["fitted_dispersions"]``: that callable evaluated on the genes
      flagged in ``varm["non_zero"]``, ``NaN`` elsewhere.

    Raises
    ------
    ValueError
        If no gene passes the filter described above.
    """
    # Check that genewise dispersions are available. If not, compute them.
    if "genewise_dispersions" not in self.varm:
        self.fit_genewise_dispersions()

    if not self.quiet:
        print("Fitting dispersion trend curve...", file=sys.stderr)
    start = time.time()
    self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)

    genes_to_fit = self.varm["non_zero"] & (
        self.varm["genewise_dispersions"] >= 10 * self.min_disp
    )

    # genes_to_fit is a boolean mask with one entry per gene, so its length
    # says nothing about how many genes were selected: test its content.
    if not np.any(genes_to_fit):
        raise ValueError("No genes to fit: all dispersions are below 10 * min_disp")

    lo = loess(
        x=np.log(self.varm["_normed_means"][genes_to_fit]),
        y=np.log(self.varm["genewise_dispersions"][genes_to_fit]),
        weights=self.varm["_normed_means"][genes_to_fit],
        surface="direct",  # to allow extrapolation
    )

    lo.fit()

    end = time.time()

    if not self.quiet:
        print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

    # The model is fitted on the log-log scale; map back with exp.
    self.uns["disp_function"] = lambda x: np.exp(lo.predict(np.log(x)).values)
    # np.nan rather than np.NaN: the capitalized alias was removed in NumPy 2.0.
    self.varm["fitted_dispersions"] = np.full(self.n_vars, np.nan)
    self.varm["fitted_dispersions"][self.varm["non_zero"]] = self.uns[
        "disp_function"
    ](self.varm["_normed_means"][self.varm["non_zero"]])

def plot_dispersions(
self, log: bool = True, save_path: Optional[str] = None, **kwargs
) -> None:
Expand Down Expand Up @@ -1001,12 +1075,25 @@ def _refit_without_outliers(

# Compute trend dispersions.
# Note: the trend curve is not refitted.
sub_dds.uns["trend_coeffs"] = self.uns["trend_coeffs"]
sub_dds.varm["_normed_means"] = sub_dds.layers["normed_counts"].mean(0)
sub_dds.varm["fitted_dispersions"] = dispersion_trend(
sub_dds.varm["_normed_means"],
sub_dds.uns["trend_coeffs"],
)

if self.trend_fit_type == "parametric":
sub_dds.uns["trend_coeffs"] = self.uns["trend_coeffs"]
sub_dds.varm["fitted_dispersions"] = dispersion_trend(
sub_dds.varm["_normed_means"],
sub_dds.uns["trend_coeffs"],
)
elif self.trend_fit_type == "local":
sub_dds.uns["disp_function"] = self.uns["disp_function"]
sub_dds.varm["fitted_dispersions"] = self.uns["disp_function"](
sub_dds.varm["_normed_means"][sub_dds.varm["non_zero"]]
)

else:
raise AttributeError(
f"Found trend_fit_type '{self.trend_fit_type}'. Expected 'parametric' or "
"'local'."
)

# Estimate MAP dispersions.
# Note: the prior variance is not recomputed.
Expand Down Expand Up @@ -1111,7 +1198,9 @@ def objective(p):
) < 1e-4:
break
elif i == niter - 1:
print("Iterative size factor fitting did not converge.", file=sys.stderr)
print(
"Iterative size factor fitting did not converge.", file=sys.stderr
)

# Restore the design matrix and free buffer
self.obsm["design_matrix"] = self.obsm["design_matrix_buffer"].copy()
Expand Down
Loading