diff --git a/tests/test_cpython_interface.py b/tests/test_cpython_interface.py index a2c5bf6..533c64c 100644 --- a/tests/test_cpython_interface.py +++ b/tests/test_cpython_interface.py @@ -47,6 +47,17 @@ def example_format_data(num_variants, num_samples): return d +def example_gt_data(num_variants, num_samples, ploidy=2): + return { + "gt": np.arange(num_variants * num_samples * ploidy) + .astype(np.int32) + .reshape(num_variants, num_samples, ploidy), + "gt_phased": (np.arange(num_variants * num_samples) % 2) + .reshape(num_variants, num_samples) + .astype(bool), + } + + def example_encoder(num_variants=1, num_samples=0, add_info=True): encoder = _vcztools.VcfEncoder( num_variants, num_samples, **example_fixed_data(num_variants, num_samples) @@ -57,6 +68,9 @@ def example_encoder(num_variants=1, num_samples=0, add_info=True): if num_samples > 0: for name, data in example_format_data(num_variants, num_samples).items(): encoder.add_format_field(name, data) + + gt_data = example_gt_data(num_variants, num_samples) + encoder.add_gt_field(gt_data["gt"], gt_data["gt_phased"]) # import sys # encoder.print_state(sys.stdout) return encoder @@ -330,13 +344,22 @@ def test_add_gt_field_unsupported_width(self): np.zeros((1, 1, 1), dtype=np.int64), np.zeros((1, 1), dtype=bool) ) + def test_add_gt_field_zero_ploidy(self): + encoder = example_encoder(1, 1) + with pytest.raises(ValueError, match="-204"): + encoder.add_gt_field( + np.zeros((1, 1, 0), dtype=np.int64), np.zeros((1, 1), dtype=bool) + ) + class TestArrays: def test_stored_data_equal(self): num_variants = 20 num_samples = 10 fixed_data = example_fixed_data(num_variants, num_samples) + gt_data = example_gt_data(num_variants, num_samples) encoder = _vcztools.VcfEncoder(num_variants, num_samples, **fixed_data) + encoder.add_gt_field(gt_data["gt"], gt_data["gt_phased"]) info_data = example_info_data(num_variants) format_data = example_format_data(num_variants, num_samples) for name, data in info_data.items(): @@ -346,7 +369,7 @@ def test_stored_data_equal(self): all_data = {**fixed_data} for name, array in info_data.items(): all_data[f"INFO/{name}"] = array - for name, array in format_data.items(): + for name, array in {**format_data, **gt_data}.items(): all_data[f"FORMAT/{name}"] = array encoder_arrays = encoder.arrays assert set(encoder.arrays.keys()) == set(all_data.keys()) diff --git a/vcztools/_vcztoolsmodule.c b/vcztools/_vcztoolsmodule.c index 3130628..fa09679 100644 --- a/vcztools/_vcztoolsmodule.c +++ b/vcztools/_vcztoolsmodule.c @@ -174,7 +174,8 @@ np_type_to_vcz_type(const char *name, PyArrayObject *array) } static int -check_dtype(const char *name, PyArrayObject *array, int type) { +check_dtype(const char *name, PyArrayObject *array, int type) +{ if (PyArray_DTYPE(array)->type_num != type) { PyErr_Format(PyExc_ValueError, "Wrong dtype for %s", name); return -1; @@ -411,8 +412,7 @@ VcfEncoder_add_gt_field(VcfEncoder *self, PyObject *args) goto out; } if (PyArray_DTYPE(gt)->kind != 'i') { - PyErr_Format( - PyExc_ValueError, "Array 'gt' has unsupported array dtype"); + PyErr_Format(PyExc_ValueError, "Array 'gt' has unsupported array dtype"); goto out; } if (check_dtype("gt_phased", gt_phased, NPY_BOOL) != 0) {