Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recompute INFO fields on sample selection #77

Merged
merged 8 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified tests/data/vcf/sample.vcf.gz
Binary file not shown.
Binary file added tests/data/vcf/sample.vcf.gz.csi
Binary file not shown.
Binary file removed tests/data/vcf/sample.vcf.gz.tbi
Binary file not shown.
4 changes: 4 additions & 0 deletions tests/test_bcftools_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def run_vcztools(args: str) -> str:
"tests/data/txt/samples.txt",
"sample.vcf.gz"),
("view -I --no-version -S tests/data/txt/samples.txt", "sample.vcf.gz"),
("view --no-version -s NA00001", "sample.vcf.gz"),
("view --no-version -s NA00001,NA00003", "sample.vcf.gz"),
("view --no-version -s HG00096", "1kg_2020_chrM.vcf.gz"),
("view --no-version -s '' --force-samples", "sample.vcf.gz")
]
)
# fmt: on
Expand Down
4 changes: 2 additions & 2 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def root(self):
"19:111\n19:112\n20:14370\n20:17330\n20:1110696\n20:1230237\n20:1234567\n20:1235237\nX:10\n",
),
(r"%INFO/DP\n", ".\n.\n14\n11\n10\n13\n9\n.\n.\n"),
(r"%AC\n", ".\n.\n.\n.\n.\n.\n3,1\n.\n.\n"),
(r"%AC{0}\n", ".\n.\n.\n.\n.\n.\n3\n.\n.\n"),
(r"%AC\n", ".\n.\n.\n.\n.\n.\n1,1\n.\n.\n"),
(r"%AC{0}\n", ".\n.\n.\n.\n.\n.\n1\n.\n.\n"),
],
)
def test(self, root, query_format, expected_result):
Expand Down
38 changes: 37 additions & 1 deletion tests/test_vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import re
from io import StringIO

import numpy as np
import pytest
from cyvcf2 import VCF
from numpy.testing import assert_array_equal

from vcztools.vcf_writer import write_vcf
from vcztools.constants import INT_FILL, INT_MISSING
from vcztools.vcf_writer import _compute_info_fields, write_vcf

from .utils import assert_vcfs_close, vcz_path_cache

Expand Down Expand Up @@ -311,3 +313,37 @@ def test_write_vcf__set_header(tmp_path):
assert variant.genotypes is not None
count += 1
assert count == 9


def test_compute_info_fields():
gt = np.array([
[[0, 0], [0, 1], [1, 1]],
[[0, 0], [0, 2], [2, 2]],
[[0, 1], [1, 2], [2, 2]],
[[INT_MISSING, INT_MISSING], [INT_MISSING, INT_MISSING], [INT_FILL, INT_FILL]],
[[INT_MISSING, INT_MISSING], [0, 3], [INT_FILL, INT_FILL]],
])
alt = np.array([
[b"A", b"B", b""],
[b"A", b"B", b"C"],
[b"A", b"B", b"C"],
[b"", b"", b""],
[b"A", b"B", b"C"]
])
expected_result = {
"AC": np.array([
[3, 0, INT_FILL],
[0, 3, 0],
[2, 3, 0],
[INT_FILL, INT_FILL, INT_FILL],
[0, 0, 1],
]),
"AN": np.array([6, 6, 6, 0, 2]),
}

computed_info_fields = _compute_info_fields(gt, alt)

assert expected_result.keys() == computed_info_fields.keys()

for key in expected_result.keys():
np.testing.assert_array_equal(expected_result[key], computed_info_fields[key])
5 changes: 5 additions & 0 deletions vcztools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def query(path, list_samples, format):
default=None,
help="Regions to include.",
)
@click.option(
"--force-samples", is_flag=True, help="Only warn about unknown sample subsets."
)
@click.option(
"-I",
"--no-update",
Expand Down Expand Up @@ -133,6 +136,7 @@ def view(
no_version,
regions,
targets,
force_samples,
no_update,
samples,
samples_file,
Expand All @@ -159,6 +163,7 @@ def view(
no_version=no_version,
variant_regions=regions,
variant_targets=targets,
no_update=no_update,
samples=samples,
drop_genotypes=drop_genotypes,
include=include,
Expand Down
75 changes: 63 additions & 12 deletions vcztools/vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
search,
)

from . import _vcztools
from . import _vcztools, constants
from .constants import RESERVED_VARIABLE_NAMES
from .filter import FilterExpressionEvaluator, FilterExpressionParser

Expand Down Expand Up @@ -87,6 +87,7 @@ def write_vcf(
no_version: bool = False,
variant_regions=None,
variant_targets=None,
no_update=None,
samples=None,
drop_genotypes: bool = False,
include: Optional[str] = None,
Expand Down Expand Up @@ -153,6 +154,8 @@ def write_vcf(
else:
all_samples = root["sample_id"][:]
sample_ids = np.array(samples.split(","))
if np.all(sample_ids == np.array("")):
sample_ids = np.empty((0,))
samples_selection = search(all_samples, sample_ids)

if not no_header and vcf_header is None:
Expand Down Expand Up @@ -208,6 +211,8 @@ def write_vcf(
contigs,
filters,
output,
drop_genotypes=drop_genotypes,
no_update=no_update,
)
else:
contigs_u = root["contig_id"][:].astype("U").tolist()
Expand Down Expand Up @@ -270,6 +275,8 @@ def write_vcf(
contigs,
filters,
output,
drop_genotypes=drop_genotypes,
no_update=no_update,
)


Expand All @@ -287,7 +294,16 @@ def get_vchunk_array(zarray, v_chunk, mask, samples_selection=None):


def c_chunk_to_vcf(
root, v_chunk, v_mask_chunk, samples_selection, contigs, filters, output
root,
v_chunk,
v_mask_chunk,
samples_selection,
contigs,
filters,
output,
*,
drop_genotypes,
no_update,
):
chrom = contigs[get_vchunk_array(root.variant_contig, v_chunk, v_mask_chunk)]
# TODO check we don't truncate silently by doing this
Expand Down Expand Up @@ -328,10 +344,25 @@ def c_chunk_to_vcf(

gt = None
gt_phased = None
if "call_genotype" in root and num_samples != 0:

if "call_genotype" in root and not drop_genotypes:
array = root["call_genotype"]
gt = get_vchunk_array(array, v_chunk, v_mask_chunk, samples_selection)
if "call_genotype_phased" in root:

if samples_selection is not None and num_samples != 0:
gt = get_vchunk_array(array, v_chunk, v_mask_chunk, samples_selection)
else:
gt = get_vchunk_array(array, v_chunk, v_mask_chunk)

if not no_update and samples_selection is not None:
# Recompute INFO/AC and INFO/AN
info_fields |= _compute_info_fields(gt, alt)
if num_samples == 0:
gt = None
if (
"call_genotype_phased" in root
and not drop_genotypes
and (samples_selection is None or num_samples > 0)
):
array = root["call_genotype_phased"]
gt_phased = get_vchunk_array(
array, v_chunk, v_mask_chunk, samples_selection
Expand Down Expand Up @@ -365,13 +396,13 @@ def c_chunk_to_vcf(
array = array.reshape((num_variants, 1))
encoder.add_info_field(name, array)

for name, array in format_fields.items():
assert num_samples > 0
if array.dtype.kind in ("O", "U"):
array = array.astype("S")
if len(array.shape) == 2:
array = array.reshape((num_variants, num_samples, 1))
encoder.add_format_field(name, array)
if num_samples != 0:
for name, array in format_fields.items():
if array.dtype.kind in ("O", "U"):
array = array.astype("S")
if len(array.shape) == 2:
array = array.reshape((num_variants, num_samples, 1))
encoder.add_format_field(name, array)
# TODO: (1) make a guess at this based on number of fields and samples,
# and (2) log a DEBUG message when we have to double.
buflen = 1024
Expand Down Expand Up @@ -605,3 +636,23 @@ def _format_fields(header_str):
fields.remove("GT")
fields.insert(0, "GT")
return fields


def _compute_info_fields(gt: np.ndarray, alt: np.ndarray):
flatter_gt = gt.reshape((gt.shape[0], -1))
allele_count = alt.shape[1] + 1

def filter_and_bincount(values: np.ndarray):
positive = values[values > 0]
return np.bincount(positive, minlength=allele_count)[1:]

computed_ac = np.apply_along_axis(filter_and_bincount, 1, flatter_gt).astype(
np.int32
)
computed_ac[alt == b""] = constants.INT_FILL
computed_an = np.sum(flatter_gt >= 0, axis=1, dtype=np.int32)

return {
"AC": computed_ac,
"AN": computed_an,
}