diff --git a/docs/changelog.rst b/docs/changelog.rst index d4d408dcb..382497819 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -19,6 +19,8 @@ New Features (:user:`timothymillar`, :pr:`1100`, :issue:`1062`) - Add :func:`display_pedigree` function. (:user:`timothymillar`, :pr:`1104`, :issue:`1097`) +- Add option to count variant alleles directly from call genotypes in function :func:`count_variant_alleles`. + (:user:`timothymillar`, :pr:`1119`, :issue:`1116`) .. Breaking changes .. ~~~~~~~~~~~~~~~~ @@ -26,8 +28,11 @@ New Features .. Deprecations .. ~~~~~~~~~~~~ -.. Improvements -.. ~~~~~~~~~~~~ +Improvements +~~~~~~~~~~~~ + +- Improve performance of :func:`variant_stats` and :func:`sample_stats` functions. + (:user:`timothymillar`, :pr:`1119`, :issue:`1116`) .. Bug fixes .. ~~~~~~~~~ diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 36fdd6c42..bf09df2b9 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -203,7 +203,7 @@ shows how it can be used in the context of doing something simple like counting # Now the result is correct -- only the third sample is heterozygous so the count should be 1. # This how many sgkit functions handle missing data internally: - sg.variant_stats(ds).variant_n_het.item(0) + sg.variant_stats(ds).variant_n_het.values.item(0) Windowing --------- @@ -320,8 +320,8 @@ Xarray and Pandas operations in a single pipeline: # for windows of size 20 variants ( ds - # Add call rate and other statistics - .pipe(sg.variant_stats) + # Add and compute call rate and other statistics + .pipe(sg.variant_stats).compute() # Apply filter to include variants present across > 80% of samples .pipe(lambda ds: ds.sel(variants=ds.variant_call_rate > .8)) # Create windows of size 20 variants diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index c380a1f95..f6b626815 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Hashable +from typing import Hashable import dask.array as da import numpy as np @@ -96,7 +96,9 @@ def count_call_alleles( def count_variant_alleles( ds: Dataset, *, + call_genotype: Hashable = variables.call_genotype, call_allele_count: Hashable = variables.call_allele_count, + using: Literal[variables.call_allele_count, variables.call_genotype] = variables.call_allele_count, # type: ignore merge: bool = True, ) -> Dataset: """Compute allele count from per-sample allele counts, or genotype calls. @@ -105,11 +107,22 @@ def count_variant_alleles( ---------- ds Dataset containing genotype calls. + call_genotype + Input variable name holding call_genotype as defined by + :data:`sgkit.variables.call_genotype_spec`. + This variable is only used if specified by the 'using' argument. call_allele_count Input variable name holding call_allele_count as defined by :data:`sgkit.variables.call_allele_count_spec`. + This variable is only used if specified by the 'using' argument. If the variable is not present in ``ds``, it will be computed using :func:`count_call_alleles`. + using + specify the variable used to calculate allele counts from. + If ``'call_allele_count'`` (the default), the result will + be calculated from the call_allele_count variable. + If ``'call_genotype'``, the result will be calculated from + the call_genotype variable. 
merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -122,6 +135,12 @@ def count_variant_alleles( of allele counts with shape (variants, alleles) and values corresponding to the number of non-missing occurrences of each allele. + Note + ---- + This method is more efficient when calculating allele counts directly from + the call_genotype variable unless the call_allele_count variable has already + been (or will be) calculated. + Examples -------- @@ -141,14 +160,28 @@ def count_variant_alleles( [2, 2], [4, 0]], dtype=uint64) """ - ds = define_variable_if_absent( - ds, variables.call_allele_count, call_allele_count, count_call_alleles - ) - variables.validate(ds, {call_allele_count: variables.call_allele_count_spec}) - - new_ds = create_dataset( - {variables.variant_allele_count: ds[call_allele_count].sum(dim="samples")} - ) + if using == variables.call_allele_count: + ds = define_variable_if_absent( + ds, variables.call_allele_count, call_allele_count, count_call_alleles + ) + variables.validate(ds, {call_allele_count: variables.call_allele_count_spec}) + AC = ds[call_allele_count].sum(dim="samples") + elif using == variables.call_genotype: + from .aggregation_numba_fns import count_alleles + + variables.validate(ds, {call_genotype: variables.call_genotype_spec}) + n_alleles = ds.dims["alleles"] + n_variant = ds.dims["variants"] + G = da.asarray(ds[call_genotype]).reshape((n_variant, -1)) + shape = (G.chunks[0], n_alleles) + # use uint64 dummy array to return uin64 counts array + N = np.empty(n_alleles, dtype=np.uint64) + AC = da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=1, new_axis=1) + AC = xr.DataArray(AC, dims=["variants", "alleles"]) + else: + options = {variables.call_genotype, variables.call_allele_count} + raise ValueError(f"The 'using' argument must be one of {options}.") + new_ds = create_dataset({variables.variant_allele_count: AC}) return conditional_merge_datasets(ds, new_ds, merge) @@ -237,18 +270,6 @@ def count_cohort_alleles( return conditional_merge_datasets(ds, new_ds, merge) -def _swap(dim: Dimension) -> Dimension: - return "samples" if dim == "variants" else "variants" - - -def call_rate(ds: Dataset, dim: Dimension, call_genotype_mask: Hashable) -> Dataset: - odim = _swap(dim)[:-1] - n_called = (~ds[call_genotype_mask].any(dim="ploidy")).sum(dim=dim) - return create_dataset( - {f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]} - ) - - def count_variant_genotypes( ds: Dataset, *, @@ -432,39 +453,6 @@ def genotype_coords( return ds -def count_genotypes( - ds: Dataset, - dim: Dimension, - call_genotype: Hashable = variables.call_genotype, - call_genotype_mask: Hashable = variables.call_genotype_mask, - merge: bool = True, -) -> Dataset: - variables.validate( - ds, - { - call_genotype_mask: variables.call_genotype_mask_spec, - call_genotype: variables.call_genotype_spec, - }, - ) - odim = _swap(dim)[:-1] - M, G = ds[call_genotype_mask].any(dim="ploidy"), ds[call_genotype] - n_hom_ref = (G == 0).all(dim="ploidy") - n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy") - n_non_ref = (G > 0).any(dim="ploidy") - n_het = ~(n_hom_alt | n_hom_ref) - # This would 0 out the `het` case with any missing calls - agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call] - new_ds = create_dataset( - { - f"{odim}_n_het": agg(n_het), # type: ignore[no-untyped-call] - f"{odim}_n_hom_ref": agg(n_hom_ref), # type: ignore[no-untyped-call] - 
f"{odim}_n_hom_alt": agg(n_hom_alt), # type: ignore[no-untyped-call] - f"{odim}_n_non_ref": agg(n_non_ref), # type: ignore[no-untyped-call] - } - ) - return conditional_merge_datasets(ds, new_ds, merge) - - def call_allele_frequencies( ds: Dataset, *, @@ -601,35 +589,9 @@ def cohort_allele_frequencies( return conditional_merge_datasets(ds, new_ds, merge) -def allele_frequency( - ds: Dataset, - call_genotype_mask: Hashable, - variant_allele_count: Hashable, -) -> Dataset: - data_vars: Dict[Hashable, Any] = {} - # only compute variant allele count if not already in dataset - if variant_allele_count in ds: - variables.validate( - ds, {variant_allele_count: variables.variant_allele_count_spec} - ) - AC = ds[variant_allele_count] - else: - AC = count_variant_alleles(ds, merge=False)[variables.variant_allele_count] - data_vars[variables.variant_allele_count] = AC - - M = ds[call_genotype_mask].stack(calls=("samples", "ploidy")) - AN = (~M).sum(dim="calls") - assert AN.shape == (ds.dims["variants"],) - - data_vars[variables.variant_allele_total] = AN - data_vars[variables.variant_allele_frequency] = AC / AN - return create_dataset(data_vars) - - def variant_stats( ds: Dataset, *, - call_genotype_mask: Hashable = variables.call_genotype_mask, call_genotype: Hashable = variables.call_genotype, variant_allele_count: Hashable = variables.variant_allele_count, merge: bool = True, @@ -644,10 +606,6 @@ def variant_stats( Input variable name holding call_genotype. Defined by :data:`sgkit.variables.call_genotype_spec`. Must be present in ``ds``. - call_genotype_mask - Input variable name holding call_genotype_mask. - Defined by :data:`sgkit.variables.call_genotype_mask_spec` - Must be present in ``ds``. variant_allele_count Input variable name holding variant_allele_count, as defined by :data:`sgkit.variables.variant_allele_count_spec`. @@ -681,40 +639,92 @@ def variant_stats( The number of occurrences of all alleles. - :data:`sgkit.variables.variant_allele_frequency_spec` (variants, alleles): The frequency of occurrence of each allele. + + Note + ---- + If the dataset contains partial genotype calls (i.e., genotype calls with + a mixture of called and missing alleles), these genotypes will be ignored + when counting the number of homozygous, heterozygous or total genotype calls. + However, the called alleles will be counted when calculating allele counts + and frequencies using :func:`count_variant_alleles`. + + Note + ---- + When used on autopolyploid genotypes, this method treats genotypes calls + with any level of heterozygosity as 'heterozygous'. Only fully homozygous + genotype calls (e.g. 0/0/0/0) will be classified as 'homozygous'. + + Warnings + -------- + This method does not support mixed-ploidy datasets. + + Raises + ------ + ValueError + If the dataset contains mixed-ploidy genotype calls. 
+ + See Also + -------- + :func:`count_variant_genotypes` """ - variables.validate( + from .aggregation_numba_fns import count_hom + + variables.validate(ds, {call_genotype: variables.call_genotype_spec}) + mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False) + if mixed_ploidy: + raise ValueError("Mixed-ploidy dataset") + AC = define_variable_if_absent( ds, - { - call_genotype: variables.call_genotype_spec, - call_genotype_mask: variables.call_genotype_mask_spec, - }, + variables.variant_allele_count, + variant_allele_count, + count_variant_alleles, + using=variables.call_genotype, # improved performance + merge=False, + )[variant_allele_count] + G = da.array(ds[call_genotype].data) + H = xr.DataArray( + da.map_blocks( + count_hom, + G, + np.zeros(3, np.uint64), + drop_axis=(1, 2), + new_axis=1, + dtype=np.int64, + chunks=(G.chunks[0], 3), + ), + dims=["variants", "categories"], ) - new_ds = xr.merge( - [ - call_rate(ds, dim="samples", call_genotype_mask=call_genotype_mask), - count_genotypes( - ds, - dim="samples", - call_genotype=call_genotype, - call_genotype_mask=call_genotype_mask, - merge=False, - ), - allele_frequency( - ds, - call_genotype_mask=call_genotype_mask, - variant_allele_count=variant_allele_count, - ), - ] + _, n_sample, _ = G.shape + n_called = H.sum(axis=-1) + call_rate = n_called / n_sample + n_hom_ref = H[:, 0] + n_hom_alt = H[:, 1] + n_het = H[:, 2] + n_non_ref = n_called - n_hom_ref + allele_total = AC.sum(axis=-1).astype(int) # backwards compatibility + new_ds = xr.Dataset( + { + variables.variant_n_called: n_called, + variables.variant_call_rate: call_rate, + variables.variant_n_het: n_het, + variables.variant_n_hom_ref: n_hom_ref, + variables.variant_n_hom_alt: n_hom_alt, + variables.variant_n_non_ref: n_non_ref, + variables.variant_allele_count: AC, + variables.variant_allele_total: allele_total, + variables.variant_allele_frequency: AC / allele_total, + } ) + # for backwards compatible behavior + if (variant_allele_count in ds) and merge: + new_ds = new_ds.drop_vars(variant_allele_count) return conditional_merge_datasets(ds, variables.validate(new_ds), merge) def sample_stats( ds: Dataset, *, - call_genotype_mask: Hashable = variables.call_genotype_mask, call_genotype: Hashable = variables.call_genotype, - variant_allele_count: Hashable = variables.variant_allele_count, merge: bool = True, ) -> Dataset: """Compute quality control sample statistics from genotype calls. @@ -727,15 +737,6 @@ def sample_stats( Input variable name holding call_genotype. Defined by :data:`sgkit.variables.call_genotype_spec`. Must be present in ``ds``. - call_genotype_mask - Input variable name holding call_genotype_mask. - Defined by :data:`sgkit.variables.call_genotype_mask_spec` - Must be present in ``ds``. - variant_allele_count - Input variable name holding variant_allele_count, - as defined by :data:`sgkit.variables.variant_allele_count_spec`. - If the variable is not present in ``ds``, it will be computed - using :func:`count_variant_alleles`. merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -758,25 +759,63 @@ def sample_stats( The number of variants with homozygous alternate calls. - :data:`sgkit.variables.sample_n_non_ref_spec` (samples): The number of variants that are not homozygous reference calls. 
+ + Note + ---- + If the dataset contains partial genotype calls (i.e., genotype calls with + a mixture of called and missing alleles), these genotypes will be ignored + when counting the number of homozygous, heterozygous or total genotype calls. + + Note + ---- + When used on autopolyploid genotypes, this method treats genotypes calls + with any level of heterozygosity as 'heterozygous'. Only fully homozygous + genotype calls (e.g. 0/0/0/0) will be classified as 'homozygous'. + + Warnings + -------- + This method does not support mixed-ploidy datasets. + + Raises + ------ + ValueError + If the dataset contains mixed-ploidy genotype calls. """ - variables.validate( - ds, - { - call_genotype: variables.call_genotype_spec, - call_genotype_mask: variables.call_genotype_mask_spec, - }, + from .aggregation_numba_fns import count_hom + + variables.validate(ds, {call_genotype: variables.call_genotype_spec}) + mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False) + if mixed_ploidy: + raise ValueError("Mixed-ploidy dataset") + G = da.array(ds[call_genotype].data) + H = xr.DataArray( + da.map_blocks( + count_hom, + G.transpose(1, 0, 2), + np.zeros(3, np.uint64), + drop_axis=(1, 2), + new_axis=1, + dtype=np.int64, + chunks=(G.chunks[1], 3), + ), + dims=["samples", "categories"], ) - new_ds = xr.merge( - [ - call_rate(ds, dim="variants", call_genotype_mask=call_genotype_mask), - count_genotypes( - ds, - dim="variants", - call_genotype=call_genotype, - call_genotype_mask=call_genotype_mask, - merge=False, - ), - ] + n_variant, _, _ = G.shape + n_called = H.sum(axis=-1) + call_rate = n_called / n_variant + n_hom_ref = H[:, 0] + n_hom_alt = H[:, 1] + n_het = H[:, 2] + n_non_ref = n_called - n_hom_ref + new_ds = xr.Dataset( + { + variables.sample_n_called: n_called, + variables.sample_call_rate: call_rate, + variables.sample_n_het: n_het, + variables.sample_n_hom_ref: n_hom_ref, + variables.sample_n_hom_alt: n_hom_alt, + variables.sample_n_non_ref: n_non_ref, + } ) return conditional_merge_datasets(ds, variables.validate(new_ds), merge) diff --git a/sgkit/stats/aggregation_numba_fns.py b/sgkit/stats/aggregation_numba_fns.py index e8a6f92e2..3335f5457 100644 --- a/sgkit/stats/aggregation_numba_fns.py +++ b/sgkit/stats/aggregation_numba_fns.py @@ -2,7 +2,7 @@ # in a separate file here, and imported dynamically to avoid # initial compilation overhead. -from sgkit.accelerate import numba_guvectorize +from sgkit.accelerate import numba_guvectorize, numba_jit from sgkit.typing import ArrayLike @@ -12,6 +12,10 @@ "void(int16[:], uint8[:], uint8[:])", "void(int32[:], uint8[:], uint8[:])", "void(int64[:], uint8[:], uint8[:])", + "void(int8[:], uint64[:], uint64[:])", + "void(int16[:], uint64[:], uint64[:])", + "void(int32[:], uint64[:], uint64[:])", + "void(int64[:], uint64[:], uint64[:])", ], "(k),(n)->(n)", ) @@ -26,9 +30,10 @@ def count_alleles( Genotype call of shape (ploidy,) containing alleles encoded as type `int` with values < 0 indicating a missing allele. _ - Dummy variable of type `uint8` and shape (alleles,) used to - define the number of unique alleles to be counted in the - return value. + Dummy variable of type `uint8` or `uint64` and shape (alleles,) + used to define the number of unique alleles to be counted in the + return value. The dtype of this array determines the dtype of the + returned array. 
Returns ------- @@ -43,3 +48,57 @@ def count_alleles( a = g[i] if a >= 0: out[a] += 1 + + +@numba_jit(nogil=True) +def _classify_hom(genotype: ArrayLike) -> int: # pragma: no cover + a0 = genotype[0] + cat = min(a0, 1) # -1, 0, 1 + for i in range(1, len(genotype)): + if cat < 0: + break + a = genotype[i] + if a != a0: + cat = 2 # het + if a < 0: + cat = -1 + return cat + + +@numba_guvectorize( # type: ignore + [ + "void(int8[:,:], uint64[:], int64[:])", + "void(int16[:,:], uint64[:], int64[:])", + "void(int32[:,:], uint64[:], int64[:])", + "void(int64[:,:], uint64[:], int64[:])", + ], + "(n, k),(c)->(c)", +) +def count_hom( + genotypes: ArrayLike, _: ArrayLike, out: ArrayLike +) -> None: # pragma: no cover + """Generalized U-function for counting homozygous and heterozygous genotypes. + + Parameters + ---------- + g + Genotype call of shape (ploidy,) containing alleles encoded as + type `int` with values < 0 indicating a missing allele. + _ + Dummy variable of type `uint64` with length 3 which determines the + number of categories returned (this should always be 3). + + Note + ---- + This method is not suitable for mixed-ploidy genotypes. + + Returns + ------- + counts : ndarray + Counts of homozygous reference, homozygous alternate, and heterozygous genotypes. + """ + out[:] = 0 + for i in range(len(genotypes)): + index = _classify_hom(genotypes[i]) + if index >= 0: + out[index] += 1 diff --git a/sgkit/stats/popgen.py b/sgkit/stats/popgen.py index c2bb05ad9..739211880 100644 --- a/sgkit/stats/popgen.py +++ b/sgkit/stats/popgen.py @@ -430,7 +430,11 @@ def Tajimas_D( [1.10393559, 1.10393559]]) """ ds = define_variable_if_absent( - ds, variables.variant_allele_count, variant_allele_count, count_variant_alleles + ds, + variables.variant_allele_count, + variant_allele_count, + count_variant_alleles, + using=variables.call_genotype, ) ds = define_variable_if_absent( ds, variables.stat_diversity, stat_diversity, diversity diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 621cea669..4f15d0061 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -7,6 +7,7 @@ import xarray as xr from xarray import Dataset +from sgkit import variables from sgkit.stats.aggregation import ( call_allele_frequencies, cohort_allele_frequencies, @@ -39,26 +40,44 @@ def get_dataset( return ds -def test_count_variant_alleles__single_variant_single_sample(): - ds = count_variant_alleles(get_dataset([[[1, 0]]])) +@pytest.mark.parametrize( + "using", [variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__single_variant_single_sample(using): + ds = count_variant_alleles(get_dataset([[[1, 0]]]), using=using) assert "call_genotype" in ds ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[1, 1]])) -def test_count_variant_alleles__multi_variant_single_sample(): - ds = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) +@pytest.mark.parametrize( + "using", [variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__multi_variant_single_sample(using): + ds = count_variant_alleles( + get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]), + using=using, + ) ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]])) -def test_count_variant_alleles__single_variant_multi_sample(): - ds = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) +@pytest.mark.parametrize( + "using", 
[variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__single_variant_multi_sample(using): + ds = count_variant_alleles( + get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]), + using=using, + ) ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[4, 4]])) -def test_count_variant_alleles__multi_variant_multi_sample(): +@pytest.mark.parametrize( + "using", [variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__multi_variant_multi_sample(using): ds = count_variant_alleles( get_dataset( [ @@ -67,13 +86,17 @@ def test_count_variant_alleles__multi_variant_multi_sample(): [[1, 1], [0, 1], [1, 0]], [[1, 1], [1, 1], [1, 1]], ] - ) + ), + using=using, ) ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]])) -def test_count_variant_alleles__missing_data(): +@pytest.mark.parametrize( + "using", [variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__missing_data(using): ds = count_variant_alleles( get_dataset( [ @@ -82,13 +105,17 @@ def test_count_variant_alleles__missing_data(): [[1, 1], [-1, -1], [-1, 0]], [[1, 1], [1, 1], [1, 1]], ] - ) + ), + using=using, ) ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]])) -def test_count_variant_alleles__higher_ploidy(): +@pytest.mark.parametrize( + "using", [variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__higher_ploidy(using): ds = count_variant_alleles( get_dataset( [ @@ -97,31 +124,51 @@ def test_count_variant_alleles__higher_ploidy(): ], n_allele=4, n_ploidy=3, - ) + ), + using=using, ) ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]])) -def test_count_variant_alleles__chunked(): +@pytest.mark.parametrize( + "using", [variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__chunked(using): rs = np.random.RandomState(0) calls = rs.randint(0, 1, size=(50, 10, 2)) ds = get_dataset(calls) - ac1 = count_variant_alleles(ds) + ac1 = count_variant_alleles(ds, using=using) # Coerce from numpy to multiple chunks in all dimensions ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) - ac2 = count_variant_alleles(ds) + ac2 = count_variant_alleles(ds, using=using) assert isinstance(ac2["variant_allele_count"].data, da.Array) xr.testing.assert_equal(ac1, ac2) -def test_count_variant_alleles__no_merge(): - ds = count_variant_alleles(get_dataset([[[1, 0]]]), merge=False) +@pytest.mark.parametrize( + "using", [variables.call_allele_count, variables.call_genotype] +) +def test_count_variant_alleles__no_merge(using): + ds = count_variant_alleles( + get_dataset([[[1, 0]]]), + merge=False, + using=using, + ) assert "call_genotype" not in ds ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[1, 1]])) +def test_count_variant_alleles__raise_on_unknown_using(): + ds = simulate_genotype_call_dataset(n_variant=1, n_sample=2) + options = {variables.call_genotype, variables.call_allele_count} + with pytest.raises( + ValueError, match=f"The 'using' argument must be one of {options}." 
+ ): + count_variant_alleles(ds, using="unknown") + + def test_count_call_alleles__single_variant_single_sample(): ds = count_call_alleles(get_dataset([[[1, 0]]])) ac = ds["call_allele_count"] @@ -683,15 +730,75 @@ def test_variant_stats(precompute_variant_allele_count): ) -@pytest.mark.parametrize("precompute_variant_allele_count", [False, True]) -def test_sample_stats(precompute_variant_allele_count): +def test_variant_stats__multi_allelic(): + ds = simulate_genotype_call_dataset(n_variant=2, n_sample=4, n_allele=4, seed=0) + ds["call_genotype"].data = [ + [[0, 0], [0, 0], [1, 1], [2, 2]], + [[0, 0], [2, 3], [0, -1], [-1, 2]], + ] + vs = variant_stats(ds) + np.testing.assert_equal(vs["variant_n_called"], np.array([4, 2])) + np.testing.assert_equal(vs["variant_call_rate"], np.array([1, 1 / 2])) + np.testing.assert_equal(vs["variant_n_hom_ref"], np.array([2, 1])) + np.testing.assert_equal(vs["variant_n_hom_alt"], np.array([2, 0])) + np.testing.assert_equal(vs["variant_n_het"], np.array([0, 1])) + np.testing.assert_equal(vs["variant_n_non_ref"], np.array([2, 1])) + np.testing.assert_equal( + vs["variant_allele_count"], np.array([[4, 2, 2, 0], [3, 0, 2, 1]]) + ) + np.testing.assert_equal(vs["variant_allele_total"], np.array([8, 6])) + np.testing.assert_equal( + vs["variant_allele_frequency"], + np.array([[4 / 8, 2 / 8, 2 / 8, 0 / 8], [3 / 6, 0 / 6, 2 / 6, 1 / 6]]), + ) + + +def test_variant_stats__tetraploid(): + ds = simulate_genotype_call_dataset(n_variant=2, n_sample=3, n_ploidy=4, seed=0) + ds["call_genotype"].data = [ + [[0, 0, 0, 0], [0, 0, 0, 1], [1, 1, 1, 1]], + [[0, 0, 1, 1], [0, 1, 1, 1], [0, 0, -1, 0]], + ] + vs = variant_stats(ds) + np.testing.assert_equal(vs["variant_n_called"], np.array([3, 2])) + np.testing.assert_equal(vs["variant_call_rate"], np.array([1, 2 / 3])) + np.testing.assert_equal(vs["variant_n_hom_ref"], np.array([1, 0])) + np.testing.assert_equal(vs["variant_n_hom_alt"], np.array([1, 0])) + np.testing.assert_equal(vs["variant_n_het"], np.array([1, 2])) + np.testing.assert_equal(vs["variant_n_non_ref"], np.array([2, 2])) + np.testing.assert_equal(vs["variant_allele_count"], np.array([[7, 5], [6, 5]])) + np.testing.assert_equal(vs["variant_allele_total"], np.array([12, 11])) + np.testing.assert_equal( + vs["variant_allele_frequency"], + np.array([[7 / 12, 5 / 12], [6 / 11, 5 / 11]]), + ) + + +@pytest.mark.parametrize( + "chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1), (100, 10, 1)] +) +def test_variant_stats__chunks(chunks): + ds = simulate_genotype_call_dataset( + n_variant=1000, n_sample=30, missing_pct=0.01, seed=0 + ) + expect = variant_stats(ds, merge=False).compute() + ds["call_genotype"] = ds["call_genotype"].chunk(chunks) + actual = variant_stats(ds, merge=False).compute() + assert actual.equals(expect) + + +def test_variant_stats__raise_on_mixed_ploidy(): + ds = simulate_genotype_call_dataset(n_variant=2, n_sample=2, n_ploidy=3, seed=0) + ds["call_genotype"].attrs["mixed_ploidy"] = True + with pytest.raises(ValueError, match="Mixed-ploidy dataset"): + variant_stats(ds) + + +def test_sample_stats(): ds = get_dataset( [[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]] ) - if precompute_variant_allele_count: - ds = count_variant_alleles(ds) ss = sample_stats(ds) - np.testing.assert_equal(ss["sample_n_called"], np.array([3, 3])) np.testing.assert_equal(ss["sample_call_rate"], np.array([0.75, 0.75])) np.testing.assert_equal(ss["sample_n_hom_ref"], np.array([0, 1])) @@ -700,6 +807,43 @@ def 
test_sample_stats(precompute_variant_allele_count): np.testing.assert_equal(ss["sample_n_non_ref"], np.array([3, 2])) +def test_sample_stats__multi_allelic(): + ds = simulate_genotype_call_dataset(n_variant=2, n_sample=4, n_allele=4, seed=0) + ds["call_genotype"].data = [ + [[0, 0], [0, 0], [1, 1], [2, 2]], + [[0, 0], [2, 3], [0, -1], [-1, 2]], + ] + vs = sample_stats(ds) + np.testing.assert_equal(vs["sample_n_called"], np.array([2, 2, 1, 1])) + np.testing.assert_equal(vs["sample_call_rate"], np.array([1, 1, 0.5, 0.5])) + np.testing.assert_equal(vs["sample_n_hom_ref"], np.array([2, 1, 0, 0])) + np.testing.assert_equal(vs["sample_n_hom_alt"], np.array([0, 0, 1, 1])) + np.testing.assert_equal(vs["sample_n_het"], np.array([0, 1, 0, 0])) + np.testing.assert_equal(vs["sample_n_non_ref"], np.array([0, 1, 1, 1])) + + +def test_sample_stats__tetraploid(): + ds = simulate_genotype_call_dataset(n_variant=2, n_sample=3, n_ploidy=4, seed=0) + ds["call_genotype"].data = [ + [[0, 0, 0, 0], [0, 0, 0, 1], [1, 1, 1, 1]], + [[0, 0, 1, 1], [0, 1, 1, 1], [0, 0, -1, 0]], + ] + vs = sample_stats(ds) + np.testing.assert_equal(vs["sample_n_called"], np.array([2, 2, 1])) + np.testing.assert_equal(vs["sample_call_rate"], np.array([1, 1, 0.5])) + np.testing.assert_equal(vs["sample_n_hom_ref"], np.array([1, 0, 0])) + np.testing.assert_equal(vs["sample_n_hom_alt"], np.array([0, 0, 1])) + np.testing.assert_equal(vs["sample_n_het"], np.array([1, 2, 0])) + np.testing.assert_equal(vs["sample_n_non_ref"], np.array([1, 2, 1])) + + +def test_sample_stats__raise_on_mixed_ploidy(): + ds = simulate_genotype_call_dataset(n_variant=2, n_sample=2, n_ploidy=3, seed=0) + ds["call_genotype"].attrs["mixed_ploidy"] = True + with pytest.raises(ValueError, match="Mixed-ploidy dataset"): + sample_stats(ds) + + def test_infer_call_ploidy(): ds = get_dataset( [
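# --- Illustrative usage sketch (not part of the diff) ---------------------
# A minimal example of the API changes introduced above: the new `using`
# keyword of count_variant_alleles and the reworked variant_stats /
# sample_stats. It assumes only the public sgkit API plus the changes in
# this patch; simulate_genotype_call_dataset is the same testing helper
# used in the tests above, and the printed values depend on the simulation.
import sgkit as sg
from sgkit import variables
from sgkit.testing import simulate_genotype_call_dataset

ds = simulate_genotype_call_dataset(
    n_variant=100, n_sample=10, missing_pct=0.01, seed=0
)

# Count alleles straight from call_genotype, skipping the intermediate
# call_allele_count variable (the default behaviour is unchanged).
ac = sg.count_variant_alleles(ds, using=variables.call_genotype)
print(ac.variant_allele_count.shape)  # (100, 2) for the default biallelic case

# variant_stats / sample_stats now derive genotype categories via the
# count_hom gufunc and allele counts via count_variant_alleles.
vs = sg.variant_stats(ds)
ss = sg.sample_stats(ds)
print(vs.variant_call_rate.values[:3], ss.sample_call_rate.values[:3])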