Skip to content

Commit

Permalink
Support genotypes in queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Will-Tyler committed Oct 3, 2024
1 parent 5c134fe commit 9929543
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
1 change: 1 addition & 0 deletions tests/test_bcftools_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
(r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz"),
(r"query -f '[%CHROM\t]\n'", "sample.vcf.gz"),
(r"query -f '[%CHROM\t]\n' -i 'POS=112'", "sample.vcf.gz"),
(r"query -f '%CHROM\t%POS\t%REF\t%ALT[\t%GT]\n'", "sample.vcf.gz"),
],
)
def test_output(tmp_path, args, vcf_name):
Expand Down
66 changes: 63 additions & 3 deletions vcztools/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pyparsing as pp
import zarr

from vcztools import constants
from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser
from vcztools.utils import open_file_like, vcf_name_to_vcz_name

Expand Down Expand Up @@ -83,12 +84,65 @@ def __call__(self, *args, **kwargs):

yield from self._generator(args[0])

def _compose_gt_generator(self) -> Callable:
def generate(root):
gt_zarray = root["call_genotype"]
v_chunk_size = gt_zarray.chunks[0]

if "call_genotype_phased" in root:
phase_zarray = root["call_genotype_phased"]
assert gt_zarray.chunks[:2] == phase_zarray.chunks
assert gt_zarray.shape[:2] == phase_zarray.shape

for v_chunk_index in range(gt_zarray.cdata_shape[0]):
start = v_chunk_index * v_chunk_size
end = start + v_chunk_size

for gt_row, phase in zip(
gt_zarray[start:end], phase_zarray[start:end]
):

def stringify(gt_and_phase: tuple):
gt, phase = gt_and_phase
gt = [
str(allele) if allele != constants.INT_MISSING else "."
for allele in gt
if allele != constants.INT_FILL
]
separator = "|" if phase else "/"
return separator.join(gt)

gt_row = gt_row.tolist()
yield map(stringify, zip(gt_row, phase))
else:
for v_chunk_index in range(gt_zarray.cdata_shape[0]):
start = v_chunk_index * v_chunk_size
end = start + v_chunk_size

for gt_row in gt_zarray[start:end]:

def stringify(gt: list[int]):
gt = [
str(allele) if allele != constants.INT_MISSING else "."
for allele in gt
if allele != constants.INT_FILL
]
return "/".join(gt)

gt_row = gt_row.tolist()
yield map(stringify, gt_row)

return generate

def _compose_tag_generator(
self, tag: str, *, subfield=False, sample_loop=False
) -> Callable:
assert tag.startswith("%")
tag = tag[1:]

if tag == "GT":
return self._compose_gt_generator()

def generate(root):
vcz_names = set(name for name, _zarray in root.items())
vcz_name = vcf_name_to_vcz_name(vcz_names, tag)
Expand Down Expand Up @@ -124,16 +178,22 @@ def generate(root):
row = "."
else:
row = f"{row:g}"
if not subfield and (
isinstance(row, np.ndarray) or isinstance(row, list)
if (
not subfield
and not sample_loop
and (isinstance(row, np.ndarray) or isinstance(row, list))
):
row = ",".join(map(str, row))

result = row if not is_missing else "."

if sample_loop:
sample_count = root["sample_id"].shape[0]
yield itertools.repeat(row, sample_count)

if isinstance(row, np.ndarray) or isinstance(row, list):
yield row
else:
yield itertools.repeat(row, sample_count)
else:
yield result

Expand Down

0 comments on commit 9929543

Please sign in to comment.