From c552a84c4aeee6235bc1481e1b5a9953ae0794b1 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 3 Sep 2024 13:27:35 +0100 Subject: [PATCH] Use sgkit.distarray for sample_stats --- .github/workflows/cubed.yml | 2 +- sgkit/stats/aggregation.py | 19 ++++++++++--------- sgkit/tests/test_aggregation.py | 11 +++++++++++ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml index 3b1ab5b3e..835565b8d 100644 --- a/.github/workflows/cubed.yml +++ b/.github/workflows/cubed.yml @@ -30,4 +30,4 @@ jobs: - name: Test with pytest run: | - pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed + pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index d3bd48596..d0338955f 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -803,22 +803,23 @@ def sample_stats( mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False) if mixed_ploidy: raise ValueError("Mixed-ploidy dataset") - G = da.asarray(ds[call_genotype].data) + GT = da.asarray(ds[call_genotype].transpose("samples", "variants", "ploidy").data) H = xr.DataArray( da.map_blocks( - count_hom, - G.transpose(1, 0, 2), + lambda *args: count_hom(*args)[:, np.newaxis, :], + GT, np.zeros(3, np.uint64), - drop_axis=(1, 2), - new_axis=1, + drop_axis=2, + new_axis=2, dtype=np.int64, - chunks=(G.chunks[1], 3), + chunks=(GT.chunks[0], 1, 3), ), - dims=["samples", "categories"], + dims=["samples", "variants", "categories"], ) - n_variant, _, _ = G.shape + H = H.sum(axis=1) + _, n_variant, _ = GT.shape n_called = H.sum(axis=-1) - call_rate = n_called / n_variant + call_rate = n_called.astype(float) / float(n_variant) n_hom_ref = H[:, 0] n_hom_alt = H[:, 1] n_het = H[:, 2] diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index f67d1e658..5ed538d03 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -857,6 +857,17 @@ def test_sample_stats__raise_on_mixed_ploidy(): sample_stats(ds) +@pytest.mark.parametrize("chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1)]) +def test_sample_stats__chunks(chunks): + ds = simulate_genotype_call_dataset( + n_variant=1000, n_sample=30, missing_pct=0.01, seed=0 + ) + expect = sample_stats(ds, merge=False).compute() + ds["call_genotype"] = ds["call_genotype"].chunk(chunks) + actual = sample_stats(ds, merge=False).compute() + assert actual.equals(expect) + + def test_infer_call_ploidy(): ds = get_dataset( [