From c359718bf7003b6ce8767207bf11c5b2df136bbe Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 4 Sep 2024 16:10:25 +0100 Subject: [PATCH] Use sgkit.distarray for Hardy-Weinberg Equilibrium --- .github/workflows/cubed.yml | 2 +- sgkit/stats/aggregation.py | 5 ++--- sgkit/stats/hwe.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml index 835565b8d..bdcf3f242 100644 --- a/.github/workflows/cubed.yml +++ b/.github/workflows/cubed.yml @@ -30,4 +30,4 @@ jobs: - name: Test with pytest run: | - pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed + pytest -v sgkit/tests/test_{aggregation,hwe}.py -k 'test_count_call_alleles or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed \ No newline at end of file diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index d0338955f..9360e318c 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -457,9 +457,8 @@ def genotype_coords( G = da.map_blocks(_index_as_genotype, X, K, new_axis=1, chunks=chunks) # allow enough room for all alleles and separators dtype = "|S{}".format(max_chars * ploidy + ploidy - 1) - S = da.map_blocks( - genotype_as_bytes, G, False, max_chars, drop_axis=1, dtype=dtype - ).astype("U") + S = da.map_blocks(genotype_as_bytes, G, False, max_chars, drop_axis=1, dtype=dtype) + S = da.astype(S, "U{}".format(max_chars * ploidy + ploidy - 1)) new_ds = create_dataset({variables.genotype_id: ("genotypes", S)}) ds = conditional_merge_datasets(ds, new_ds, merge) if assign_coords: diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index 8b9f1ffeb..3dbfae8a6 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -1,9 +1,9 @@ from typing import Hashable, Optional -import dask.array as da import numpy as np from xarray import Dataset +import sgkit.distarray as da from sgkit import variables from sgkit.accelerate import numba_jit from sgkit.stats.aggregation import count_variant_genotypes