Skip to content

Commit

Permalink
Endelman-Jannink GRM #1062
Browse files Browse the repository at this point in the history
  • Loading branch information
timothymillar authored and mergify[bot] committed Jul 14, 2023
1 parent 07e75c2 commit 919d3a5
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 48 deletions.
142 changes: 108 additions & 34 deletions sgkit/stats/grm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Hashable, Optional
from typing import Hashable, Optional, Tuple

import dask.array as da
import numpy as np
Expand All @@ -9,13 +9,17 @@
from sgkit.typing import ArrayLike
from sgkit.utils import conditional_merge_datasets, create_dataset

EST_VAN_RADEN = "VanRaden"
EST_ENDELMAN_JANNINK = "Endelman-Jannink"
EST_MARTINI = "Martini"


def _grm_VanRaden(
call_dosage: ArrayLike,
ancestral_frequency: ArrayLike,
ploidy: int,
skipna: bool = False,
):
) -> ArrayLike:
ancestral_dosage = ancestral_frequency * ploidy
M = call_dosage - ancestral_dosage[:, None]
if skipna:
Expand All @@ -32,18 +36,69 @@ def _grm_VanRaden(
return G


def _shrinkage_Ledoit_Wolf(X: ArrayLike) -> Tuple[ArrayLike, float]:
# shrinkage estimator of Ledoit & Wolf 2004
# following the notation of Endelman & Jannink
# X is an m * n matrix with mean centered columns
m, n = X.shape
S = (X.T @ X) / m
T = da.diag(S).mean() * da.eye(n)
# calculate scaling parameter delta
X2 = X**2
Gamma = X2.T @ X2 / m # E&J eq 25
delta_numerator = (Gamma - S**2).sum() / m # Error in E&J eq 24?
delta_denominator = ((S - T) ** 2).sum() # squared Frobenius norm
delta = delta_numerator / delta_denominator # E&J eq 20
# delta must be in [0 1]
delta = da.maximum(delta, 0.0)
delta = da.minimum(delta, 1.0)
Cov = delta * T + (1 - delta) * S # E&J eq 16
return Cov, delta


def _grm_Endelman_Jannink(
call_dosage: ArrayLike,
ancestral_frequency: ArrayLike,
ploidy: int,
skipna: bool = False,
) -> ArrayLike:
if skipna:
raise NotImplementedError(
f"The 'skipna' option is not implemented for the '{EST_ENDELMAN_JANNINK}' estimator"
)
ancestral_dosage = ancestral_frequency * ploidy
W = call_dosage - ancestral_dosage[:, None]
W_mean = da.nanmean(W, axis=0, keepdims=True)
X = W - W_mean # mean centered
Cov, _ = _shrinkage_Ledoit_Wolf(X)
# E&J eq 17
numerator = Cov + W_mean.T @ W_mean
denominator = ploidy * da.mean(ancestral_frequency * (1 - ancestral_frequency))
G = numerator / denominator
return G


def genomic_relationship(
ds: Dataset,
*,
call_dosage: Hashable = variables.call_dosage,
estimator: Optional[Literal["VanRaden"]] = None,
estimator: Optional[Literal[EST_VAN_RADEN, EST_ENDELMAN_JANNINK]] = None, # type: ignore
ancestral_frequency: Optional[Hashable] = None,
ploidy: Optional[int] = None,
skipna: bool = False,
merge: bool = True,
) -> Dataset:
"""Compute a genomic relationship matrix (AKA the GRM or G-matrix).
The following estimators are supported:
* **VanRaden**: the first estimator described by VanRaden 2008 [1] and
generalized to autopolyploids by Ashraf et al 2016 [2] and Bilton 2020 [3].
* **Endelman-Jannink**: a shrinkage estimator described by Endelman and
Jannick 2012 [4]. This is based on the VanRaden estimator and aims to
improve the accuracy of estimated breeding values with low-density markers.
Parameters
----------
ds
Expand All @@ -52,10 +107,8 @@ def genomic_relationship(
Input variable name holding call_dosage as defined by
:data:`sgkit.variables.call_dosage_spec`.
estimator
Specifies a relatedness estimator to use. Currently the only option
is ``"VanRaden"`` which uses the method described by VanRaden 2008 [1]
and generalized to autopolyploids by Ashraf et al 2016 [2] and
Bilton 2020 [3].
Specifies the relatedness estimator to use. Must be one of
``'VanRaden'`` (the default) or ``'Endelman-Jannink'``.
ancestral_frequency
Frequency of variant alleles corresponding to call_dosage within
the ancestral/base/reference population.
Expand Down Expand Up @@ -83,10 +136,6 @@ def genomic_relationship(
which is a matrix of pairwise relationships among all samples.
The dimensions are named ``samples_0`` and ``samples_1``.
Warnings
--------
This function is only applicable to fixed-ploidy, biallelic datasets.
Raises
------
ValueError
Expand All @@ -98,6 +147,17 @@ def genomic_relationship(
ValueError
If ancestral_frequency is the incorrect shape.
Note
----
The shrinkage parameter for the Endelman-Jannink estimator depends upon
the total number of variants in the dataset. Monomorphic variants need
to be removed from the dataset in order for the resulting estimates to
match those found using the default settings in rrBLUP [5].
Warnings
--------
This function is only applicable to fixed-ploidy, biallelic datasets.
Examples
--------
Expand Down Expand Up @@ -217,31 +277,45 @@ def genomic_relationship(
"Developing statistical methods for genetic analysis of genotypes from
genotyping-by-sequencing data"
PhD thesis, University of Otago.
[4] - J. B. Endelman and J. -L. Jannink 2012.
"Shrinkage Estimation of the Realized Relationship Matrix"
G3 2: 1405-1413.
[5] - J. B. Endelman
"Ridge regression and other kernels for genomic selection with R package"
Plant Genome 4: 250-255.
"""
variables.validate(
ds,
{call_dosage: variables.call_dosage_spec},
)

estimator = estimator or "VanRaden"
if estimator not in {"VanRaden"}:
raise ValueError("Unknown estimator '{}'".format(estimator))
estimator = estimator or EST_VAN_RADEN
# TODO: raise on mixed ploidy
ploidy = ploidy or ds.dims.get("ploidy")
if ploidy is None:
raise ValueError("Ploidy must be specified when the ploidy dimension is absent")

# VanRaden GRM
cd = da.array(ds[call_dosage].data)
n_variants, _ = cd.shape
if ancestral_frequency is None:
raise ValueError("The 'VanRaden' estimator requires ancestral_frequency")
af = da.array(ds[ancestral_frequency].data)
if af.shape != (n_variants,):
raise ValueError(
"The ancestral_frequency variable must have one value per variant"
)
G = _grm_VanRaden(cd, af, ploidy=ploidy, skipna=skipna)
dosage = da.array(ds[call_dosage].data)

# estimators requiring 'ancestral_frequency'
if estimator in {EST_VAN_RADEN, EST_ENDELMAN_JANNINK}:
n_variants, _ = dosage.shape
if ancestral_frequency is None:
raise ValueError(
f"The '{estimator}' estimator requires the 'ancestral_frequency' argument"
)
af = da.array(ds[ancestral_frequency].data)
if af.shape != (n_variants,):
raise ValueError(
"The ancestral_frequency variable must have one value per variant"
)
if estimator == EST_VAN_RADEN:
G = _grm_VanRaden(dosage, af, ploidy=ploidy, skipna=skipna)
elif estimator == EST_ENDELMAN_JANNINK:
G = _grm_Endelman_Jannink(dosage, af, ploidy=ploidy, skipna=skipna)
else:
raise ValueError(f"Unknown estimator '{estimator}'")

new_ds = create_dataset(
{
Expand Down Expand Up @@ -451,7 +525,7 @@ def hybrid_relationship(
pedigree_relationship: Hashable = None,
pedigree_subset_inverse_relationship: Hashable = None,
genomic_sample: Optional[Hashable] = None,
estimator: Optional[Literal["Martini"]] = None,
estimator: Optional[Literal[EST_MARTINI]] = None, # type: ignore
tau: float = 1.0,
omega: float = 1.0,
merge: bool = True,
Expand Down Expand Up @@ -585,9 +659,9 @@ def hybrid_relationship(
Journal of Dairy Science 93 (2): 743-752.
"""
if estimator is None:
estimator = "Martini"
if estimator not in {"Martini"}:
raise ValueError("Unknown estimator '{}'".format(estimator))
estimator = EST_MARTINI
if estimator not in {EST_MARTINI}:
raise ValueError(f"Unknown estimator '{estimator}'")
variables.validate(
ds,
{
Expand Down Expand Up @@ -697,7 +771,7 @@ def hybrid_inverse_relationship(
pedigree_inverse_relationship: Hashable,
pedigree_subset_inverse_relationship: Hashable = None,
genomic_sample: Optional[Hashable] = None,
estimator: Optional[Literal["Martini"]] = None,
estimator: Optional[Literal[EST_MARTINI]] = None, # type: ignore
tau: float = 1.0,
omega: float = 1.0,
merge: bool = True,
Expand Down Expand Up @@ -847,9 +921,9 @@ def hybrid_inverse_relationship(
Journal of Dairy Science 93 (2): 743-752.
"""
if estimator is None:
estimator = "Martini"
if estimator not in {"Martini"}:
raise ValueError("Unknown estimator '{}'".format(estimator))
estimator = EST_MARTINI
if estimator not in {EST_MARTINI}:
raise ValueError(f"Unknown estimator '{estimator}'")
variables.validate(
ds,
{
Expand Down
Loading

0 comments on commit 919d3a5

Please sign in to comment.