From 992954332e320c069ba8f2e74a7c631b8cc78be4 Mon Sep 17 00:00:00 2001
From: willtyler
Date: Thu, 3 Oct 2024 17:19:32 +0000
Subject: [PATCH] Support genotypes in queries

---
Review revision:
- hoist the per-row `stringify` closures into a single helper and merge the
  duplicated phased/unphased chunk loops in `_compose_gt_generator`;
- validate `call_genotype_phased` with a ValueError instead of asserts, so
  the check is not stripped under `python -O`;
- use the tuple form of isinstance() in `_compose_tag_generator`.
Hunk offsets and the diffstat are recomputed for the revised content; the
blob ids on the vcztools/query.py index line are not. Blank context lines
and context indentation were restored from the hunk counts and the code's
nesting.

 tests/test_bcftools_validation.py |  1 +
 vcztools/query.py                 | 65 +++++++++++++++++++++++++++++--
 2 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py
index a3d742f..f17f089 100644
--- a/tests/test_bcftools_validation.py
+++ b/tests/test_bcftools_validation.py
@@ -130,6 +130,7 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
         (r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz"),
         (r"query -f '[%CHROM\t]\n'", "sample.vcf.gz"),
         (r"query -f '[%CHROM\t]\n' -i 'POS=112'", "sample.vcf.gz"),
+        (r"query -f '%CHROM\t%POS\t%REF\t%ALT[\t%GT]\n'", "sample.vcf.gz"),
     ],
 )
 def test_output(tmp_path, args, vcf_name):
diff --git a/vcztools/query.py b/vcztools/query.py
index a93dbb1..0200a96 100644
--- a/vcztools/query.py
+++ b/vcztools/query.py
@@ -7,6 +7,7 @@
 import pyparsing as pp
 import zarr
 
+from vcztools import constants
 from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser
 from vcztools.utils import open_file_like, vcf_name_to_vcz_name
 
@@ -83,12 +84,64 @@ def __call__(self, *args, **kwargs):
 
         yield from self._generator(args[0])
 
+    def _compose_gt_generator(self) -> Callable:
+        """Compose the generator used for the ``%GT`` tag.
+
+        The returned callable takes the VCZ root and yields, per variant, an
+        iterable with one genotype string per sample. Missing alleles are
+        written as ".", fill values are skipped, and alleles are joined with
+        "|" for phased calls and "/" otherwise.
+        """
+
+        def stringify(gt, phased=False) -> str:
+            alleles = [
+                str(allele) if allele != constants.INT_MISSING else "."
+                for allele in gt
+                if allele != constants.INT_FILL
+            ]
+            return ("|" if phased else "/").join(alleles)
+
+        def generate(root):
+            gt_zarray = root["call_genotype"]
+            v_chunk_size = gt_zarray.chunks[0]
+            phase_zarray = None
+
+            if "call_genotype_phased" in root:
+                phase_zarray = root["call_genotype_phased"]
+                # Raise instead of assert so the check survives ``python -O``.
+                if (
+                    gt_zarray.chunks[:2] != phase_zarray.chunks
+                    or gt_zarray.shape[:2] != phase_zarray.shape
+                ):
+                    raise ValueError(
+                        "call_genotype and call_genotype_phased must share "
+                        "their variant/sample shape and chunking"
+                    )
+
+            for v_chunk_index in range(gt_zarray.cdata_shape[0]):
+                start = v_chunk_index * v_chunk_size
+                end = start + v_chunk_size
+                gt_chunk = gt_zarray[start:end]
+
+                if phase_zarray is None:
+                    for gt_row in gt_chunk:
+                        yield map(stringify, gt_row.tolist())
+                else:
+                    phase_chunk = phase_zarray[start:end]
+                    for gt_row, phase_row in zip(gt_chunk, phase_chunk):
+                        yield map(stringify, gt_row.tolist(), phase_row)
+
+        return generate
+
     def _compose_tag_generator(
         self, tag: str, *, subfield=False, sample_loop=False
     ) -> Callable:
         assert tag.startswith("%")
         tag = tag[1:]
 
+        if tag == "GT":
+            return self._compose_gt_generator()
+
         def generate(root):
             vcz_names = set(name for name, _zarray in root.items())
             vcz_name = vcf_name_to_vcz_name(vcz_names, tag)
@@ -124,8 +177,10 @@ def generate(root):
                             row = "."
                         else:
                             row = f"{row:g}"
-                    if not subfield and (
-                        isinstance(row, np.ndarray) or isinstance(row, list)
+                    if (
+                        not subfield
+                        and not sample_loop
+                        and isinstance(row, (np.ndarray, list))
                     ):
                         row = ",".join(map(str, row))
 
@@ -133,7 +188,11 @@ def generate(root):
 
                     if sample_loop:
                         sample_count = root["sample_id"].shape[0]
-                        yield itertools.repeat(row, sample_count)
+
+                        if isinstance(row, (np.ndarray, list)):
+                            yield row
+                        else:
+                            yield itertools.repeat(row, sample_count)
                     else:
                         yield result
 