Skip to content

Commit

Permalink
Tighten up GT field testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 16, 2024
1 parent f66fc24 commit 01e69ae
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
25 changes: 24 additions & 1 deletion tests/test_cpython_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def example_format_data(num_variants, num_samples):
return d


def example_gt_data(num_variants, num_samples, ploidy=2):
    """Build deterministic example genotype arrays for the encoder tests.

    Returns a dict with two entries:

    - ``"gt"``: an int32 array of shape
      ``(num_variants, num_samples, ploidy)`` holding the consecutive
      integers 0, 1, 2, ... in row-major order.
    - ``"gt_phased"``: a bool array of shape ``(num_variants, num_samples)``
      whose flattened values alternate False, True, False, ...
    """
    call_shape = (num_variants, num_samples)
    num_calls = num_variants * num_samples

    # Every allele slot gets a distinct value so misplaced data is detectable.
    alleles = np.arange(num_calls * ploidy).astype(np.int32)
    # Parity of the flat call index gives an alternating phased/unphased flag.
    phase_parity = np.arange(num_calls) % 2

    return {
        "gt": alleles.reshape(*call_shape, ploidy),
        "gt_phased": phase_parity.reshape(call_shape).astype(bool),
    }


def example_encoder(num_variants=1, num_samples=0, add_info=True):
encoder = _vcztools.VcfEncoder(
num_variants, num_samples, **example_fixed_data(num_variants, num_samples)
Expand All @@ -57,6 +68,9 @@ def example_encoder(num_variants=1, num_samples=0, add_info=True):
if num_samples > 0:
for name, data in example_format_data(num_variants, num_samples).items():
encoder.add_format_field(name, data)

gt_data = example_gt_data(num_variants, num_samples)
encoder.add_gt_field(gt_data["gt"], gt_data["gt_phased"])
# import sys
# encoder.print_state(sys.stdout)
return encoder
Expand Down Expand Up @@ -330,13 +344,22 @@ def test_add_gt_field_unsupported_width(self):
np.zeros((1, 1, 1), dtype=np.int64), np.zeros((1, 1), dtype=bool)
)

def test_add_gt_field_zero_ploidy(self):
    """Adding a GT array whose last (ploidy) axis has length 0 must fail.

    The encoder is expected to raise ValueError with "-204" in the message.
    NOTE(review): "-204" is presumably a low-level error code surfaced by
    the C library -- confirm which condition it denotes.
    """
    encoder = example_encoder(1, 1)
    # One variant, one sample, zero alleles per call.
    zero_ploidy_gt = np.zeros((1, 1, 0), dtype=np.int64)
    phased = np.zeros((1, 1), dtype=bool)
    with pytest.raises(ValueError, match="-204"):
        encoder.add_gt_field(zero_ploidy_gt, phased)


class TestArrays:
def test_stored_data_equal(self):
num_variants = 20
num_samples = 10
fixed_data = example_fixed_data(num_variants, num_samples)
gt_data = example_gt_data(num_variants, num_samples)
encoder = _vcztools.VcfEncoder(num_variants, num_samples, **fixed_data)
encoder.add_gt_field(gt_data["gt"], gt_data["gt_phased"])
info_data = example_info_data(num_variants)
format_data = example_format_data(num_variants, num_samples)
for name, data in info_data.items():
Expand All @@ -346,7 +369,7 @@ def test_stored_data_equal(self):
all_data = {**fixed_data}
for name, array in info_data.items():
all_data[f"INFO/{name}"] = array
for name, array in format_data.items():
for name, array in {**format_data, **gt_data}.items():
all_data[f"FORMAT/{name}"] = array
encoder_arrays = encoder.arrays
assert set(encoder.arrays.keys()) == set(all_data.keys())
Expand Down
6 changes: 3 additions & 3 deletions vcztools/_vcztoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ np_type_to_vcz_type(const char *name, PyArrayObject *array)
}

static int
check_dtype(const char *name, PyArrayObject *array, int type) {
check_dtype(const char *name, PyArrayObject *array, int type)
{
if (PyArray_DTYPE(array)->type_num != type) {
PyErr_Format(PyExc_ValueError, "Wrong dtype for %s", name);
return -1;
Expand Down Expand Up @@ -411,8 +412,7 @@ VcfEncoder_add_gt_field(VcfEncoder *self, PyObject *args)
goto out;
}
if (PyArray_DTYPE(gt)->kind != 'i') {
PyErr_Format(
PyExc_ValueError, "Array 'gt' has unsupported array dtype");
PyErr_Format(PyExc_ValueError, "Array 'gt' has unsupported array dtype");
goto out;
}
if (check_dtype("gt_phased", gt_phased, NPY_BOOL) != 0) {
Expand Down

0 comments on commit 01e69ae

Please sign in to comment.