Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query format #64

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/test_bcftools_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def test_vcf_output(tmp_path, args, vcf_file):
("index --nrecords", "1kg_2020_chrM.vcf.gz"),
("query -l", "sample.vcf.gz"),
("query --list-samples", "1kg_2020_chrM.vcf.gz"),
(r"query -f 'A\n'", "sample.vcf.gz"),
(r"query -f '%CHROM:%POS\n'", "sample.vcf.gz"),
(r"query -f '%INFO/DP\n'", "sample.vcf.gz"),
(r"query -f '%AC{0}\n'", "sample.vcf.gz"),
(r"query -f '%REF\t%ALT\n'", "sample.vcf.gz"),
(r"query -f '%ALT{1}\n'", "sample.vcf.gz"),
(r"query -f '%ID\n'", "sample.vcf.gz"),
(r"query -f '%QUAL\n'", "sample.vcf.gz"),
(r"query -f '%FILTER\n'", "sample.vcf.gz"),
(r"query --format '%FILTER\n'", "1kg_2020_chrM.vcf.gz"),
],
)
def test_output(tmp_path, args, vcf_name):
Expand Down
90 changes: 89 additions & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import pathlib
from io import StringIO

import pyparsing as pp
import pytest
import zarr

from tests.utils import vcz_path_cache
from vcztools.query import list_samples
from vcztools.query import QueryFormatGenerator, QueryFormatParser, list_samples


def test_list_samples(tmp_path):
Expand All @@ -13,3 +17,87 @@ def test_list_samples(tmp_path):
with StringIO() as output:
list_samples(vcz_path, output)
assert output.getvalue() == expected_output


class TestQueryFormatParser:
    """Tests for the token lists produced by ``QueryFormatParser``."""

    # (format expression, expected ``as_list()`` output) pairs that must parse.
    # Escapes are written as raw strings because the parser receives the
    # two-character sequences backslash-n / backslash-t, not real control chars.
    _VALID_CASES = [
        ("%CHROM", ["%CHROM"]),
        (r"\n", ["\n"]),
        (r"\t", ["\t"]),
        (r"%CHROM\n", ["%CHROM", "\n"]),
        ("%CHROM %POS %REF", ["%CHROM", " ", "%POS", " ", "%REF"]),
        (r"%CHROM %POS0 %REF\n", ["%CHROM", " ", "%POS0", " ", "%REF", "\n"]),
        (
            r"%CHROM\t%POS\t%REF\t%ALT{0}\n",
            ["%CHROM", "\t", "%POS", "\t", "%REF", "\t", ["%ALT", 0], "\n"],
        ),
        (
            r"%CHROM\t%POS0\t%END\t%ID\n",
            ["%CHROM", "\t", "%POS0", "\t", "%END", "\t", "%ID", "\n"],
        ),
        (r"%CHROM:%POS\n", ["%CHROM", ":", "%POS", "\n"]),
        (r"%AC{1}\n", [["%AC", 1], "\n"]),
        (
            r"Read depth: %INFO/DP\n",
            ["Read", " ", "depth:", " ", "%INFO/DP", "\n"],
        ),
    ]

    # Expressions the parser is expected to reject with a ParseException.
    _INVALID_CASES = [
        "%ac",
        "%AC {1}",
        "% CHROM",
    ]

    @pytest.fixture()
    def parser(self):
        """Provide a fresh ``QueryFormatParser`` instance to each test."""
        return QueryFormatParser()

    @pytest.mark.parametrize(("expression", "expected_result"), _VALID_CASES)
    def test_valid_expressions(self, parser, expression, expected_result):
        """A well-formed expression parses into the expected token list."""
        tokens = parser(expression).as_list()
        assert tokens == expected_result

    @pytest.mark.parametrize("expression", _INVALID_CASES)
    def test_invalid_expressions(self, parser, expression):
        """A malformed expression raises ``pp.ParseException``."""
        with pytest.raises(pp.ParseException):
            parser(expression)


class TestQueryFormatEvaluator:
    """Tests for the text ``QueryFormatGenerator`` renders from sample data."""

    # (query format, expected concatenated output) pairs given as plain strings.
    # Every expected value below consists of nine records.
    _STRING_FORMAT_CASES = [
        (r"A\t", "A\t" * 9),
        (r"CHROM", "CHROM" * 9),
        (
            r"%CHROM:%POS\n",
            "19:111\n"
            "19:112\n"
            "20:14370\n"
            "20:17330\n"
            "20:1110696\n"
            "20:1230237\n"
            "20:1234567\n"
            "20:1235237\n"
            "X:10\n",
        ),
        (r"%INFO/DP\n", ".\n.\n14\n11\n10\n13\n9\n.\n.\n"),
        (r"%AC\n", ".\n.\n.\n.\n.\n.\n3,1\n.\n.\n"),
        (r"%AC{0}\n", ".\n.\n.\n.\n.\n.\n3\n.\n.\n"),
    ]

    # Cases that are parsed up front and handed over as ``ParseResults``.
    _PARSED_FORMAT_CASES = [
        (r"%QUAL\n", "9.6\n10\n29\n3\n67\n47\n50\n.\n10\n"),
    ]

    @pytest.fixture()
    def root(self):
        """Open the VCZ counterpart of ``sample.vcf.gz`` read-only.

        The path is relative, so this presumably expects the tests to run
        from the repository root -- confirm against the test setup.
        """
        source_vcf = pathlib.Path("tests/data/vcf/sample.vcf.gz")
        return zarr.open(vcz_path_cache(source_vcf), mode="r")

    @pytest.mark.parametrize(
        ("query_format", "expected_result"), _STRING_FORMAT_CASES
    )
    def test(self, root, query_format, expected_result):
        """A format given as a string renders the expected text."""
        render = QueryFormatGenerator(query_format)
        rendered = "".join(render(root))
        assert rendered == expected_result

    @pytest.mark.parametrize(
        ("query_format", "expected_result"), _PARSED_FORMAT_CASES
    )
    def test_with_parse_results(self, root, query_format, expected_result):
        """A format given as pre-parsed ``ParseResults`` renders the same way."""
        parsed_format = QueryFormatParser()(query_format)
        render = QueryFormatGenerator(parsed_format)
        rendered = "".join(render(root))
        assert rendered == expected_result
5 changes: 4 additions & 1 deletion vcztools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ def index(path, nrecords):
is_flag=True,
help="List the sample IDs and exit.",
)
def query(path, list_samples):
@click.option("-f", "--format", type=str, help="The format of the output.")
def query(path, list_samples, format):
if list_samples:
query_module.list_samples(path)
return

query_module.write_query(path, query_format=format)


@click.command
@click.argument("path", type=click.Path())
Expand Down
170 changes: 169 additions & 1 deletion vcztools/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import functools
import itertools
import math
from typing import Callable, Union

import numpy as np
import pyparsing as pp
import zarr

from vcztools.utils import open_file_like
from vcztools.utils import open_file_like, vcf_name_to_vcz_name


def list_samples(vcz_path, output=None):
Expand All @@ -9,3 +16,164 @@ def list_samples(vcz_path, output=None):
with open_file_like(output) as output:
sample_ids = root["sample_id"][:]
print("\n".join(sample_ids), file=output)


class QueryFormatParser:
    """Tokenise a query format string using a pyparsing grammar.

    Calling an instance with a single format string returns the
    ``pp.ParseResults`` of parsing the whole string. The result mixes:

    * tag strings such as ``"%CHROM"`` or ``"%INFO/DP"``;
    * ``[tag, index]`` groups for subscripted tags such as ``%ALT{0}``;
    * real newline / tab characters for the escapes ``\\n`` and ``\\t``;
    * runs of whitespace and runs of other literal text, kept verbatim.

    Input that the grammar cannot consume in full raises
    ``pp.ParseException``.
    """

    def __init__(self):
        open_brace = pp.Literal("{").suppress()
        close_brace = pp.Literal("}").suppress()

        # "%INFO/<NAME>" is listed first in the alternation so it wins over
        # the plain "%<NAME>" form, which would otherwise stop at "%INFO".
        info_tag = pp.Combine(pp.Literal("%INFO/") + pp.Word(pp.srange("[A-Z]")))
        plain_tag = pp.Combine(pp.Literal("%") + pp.Regex(r"[A-Z]+\d?"))
        tag = info_tag | plain_tag

        # A tag followed immediately by "{<integer>}"; the braces are dropped
        # and the pair is grouped so it appears as one nested result.
        subscripted_tag = pp.Group(
            tag + open_brace + pp.common.integer + close_brace
        )

        # Two-character escape sequences are replaced by the real characters.
        newline_escape = pp.Literal("\\n").set_parse_action(pp.replace_with("\n"))
        tab_escape = pp.Literal("\\t").set_parse_action(pp.replace_with("\t"))

        # Any other printable run, excluding characters with special meaning.
        literal_text = pp.Word(pp.printables, exclude_chars=r"\{}[]%")

        token = (
            subscripted_tag
            | tag
            | newline_escape
            | tab_escape
            | pp.White()
            | literal_text
        )
        # Whitespace is significant in the output, so it must not be skipped.
        grammar = pp.ZeroOrMore(token).leave_whitespace()

        self._parser = functools.partial(grammar.parse_string, parse_all=True)

    def __call__(self, *args, **kwargs):
        """Parse the single positional format string and return the results."""
        assert len(args) == 1
        assert not kwargs

        (expression,) = args
        return self._parser(expression)


class QueryFormatGenerator:
    """Render one output string per variant according to a query format.

    The format may be given as a string, which is parsed with
    ``QueryFormatParser``, or as an already parsed ``pp.ParseResults``.
    Calling the instance with an opened VCZ group yields the rendered text
    for each variant in turn. An instance can be called any number of
    times; each call starts a fresh pass over the data.
    """

    def __init__(self, query_format: Union[str, pp.ParseResults]):
        if isinstance(query_format, str):
            parser = QueryFormatParser()
            parse_results = parser(query_format)
        else:
            assert isinstance(query_format, pp.ParseResults)
            parse_results = query_format

        self._generator = self._compose_generator(parse_results)

    def __call__(self, *args, **kwargs):
        """Yield the rendered string for each variant of the given root group.

        Exactly one positional argument (the opened group) is accepted.
        """
        assert len(args) == 1
        assert not kwargs

        yield from self._generator(args[0])

    def _compose_tag_generator(self, tag: str, *, subscript=False) -> Callable:
        """Build a generator function yielding one value per variant for *tag*.

        *tag* is the parsed tag including its leading ``%``. With
        ``subscript=False`` multi-valued rows are joined with commas; with
        ``subscript=True`` they are yielded as sequences so that the caller
        can index into them.
        """
        assert tag.startswith("%")
        tag = tag[1:]

        def generate(root):
            vcz_names = {name for name, _zarray in root.items()}
            vcz_name = vcf_name_to_vcz_name(vcz_names, tag)
            zarray = root[vcz_name]
            # Lookup tables are only loaded for the tags that index into them.
            contig_ids = root["contig_id"][:] if tag == "CHROM" else None
            filter_ids = root["filter_id"][:] if tag == "FILTER" else None
            v_chunk_size = zarray.chunks[0]

            # Read the array one chunk (along the first axis) at a time.
            for v_chunk_index in range(zarray.cdata_shape[0]):
                start = v_chunk_index * v_chunk_size
                end = start + v_chunk_size

                for row in zarray[start:end]:
                    # Evaluated on the raw row, before any tag-specific
                    # conversion below replaces it.
                    is_missing = np.any(row == -1)

                    if tag == "CHROM":
                        assert contig_ids is not None
                        row = contig_ids[row]
                    elif tag == "REF":
                        row = row[0]
                    elif tag == "ALT":
                        # Drop empty entries after the first (REF) allele.
                        row = [allele for allele in row[1:] if allele]
                        row = row or "."
                    elif tag == "FILTER":
                        assert filter_ids is not None

                        if np.any(row):
                            row = filter_ids[row]
                        else:
                            row = "."
                    elif tag == "QUAL":
                        if math.isnan(row):
                            row = "."
                        else:
                            row = f"{row:g}"

                    if not subscript and isinstance(row, (np.ndarray, list)):
                        row = ",".join(map(str, row))

                    yield row if not is_missing else "."

        return generate

    def _compose_subscript_generator(self, parse_results: pp.ParseResults) -> Callable:
        """Build a generator function for a ``%TAG{index}`` element.

        Yields the element at *index* of each row, or ``"."`` when the row
        is missing or has no element at that index.
        """
        assert len(parse_results) == 2

        tag, subscript_index = parse_results
        tag_generator = self._compose_tag_generator(tag, subscript=True)

        def generate(root):
            for value in tag_generator(root):
                if isinstance(value, str):
                    # The only string a subscripted tag generator is expected
                    # to produce is the missing-value marker.
                    assert value == "."
                    yield "."
                elif subscript_index < len(value):
                    yield value[subscript_index]
                else:
                    yield "."

        return generate

    def _compose_element_generator(
        self, element: Union[str, pp.ParseResults]
    ) -> Callable:
        """Build the generator function for one parsed element of the format.

        Nested parse results are subscripted tags, strings starting with
        ``%`` are plain tags, and any other string is literal text repeated
        once per variant.
        """
        if isinstance(element, pp.ParseResults):
            return self._compose_subscript_generator(element)

        assert isinstance(element, str)

        if element.startswith("%"):
            return self._compose_tag_generator(element)

        def generate(root):
            variant_count = root["variant_position"].shape[0]
            yield from itertools.repeat(element, variant_count)

        return generate

    def _compose_generator(self, parse_results: pp.ParseResults) -> Callable:
        """Combine the per-element generators into a per-variant renderer.

        The returned function takes the root group and yields, for each
        variant, the concatenation of every element's value for it.
        """
        # This must be a list rather than a generator expression: a generator
        # expression is exhausted by the first call to ``generate``, so every
        # later call on the same instance would silently yield nothing.
        generators = [
            self._compose_element_generator(element) for element in parse_results
        ]

        def generate(root):
            iterables = (generator(root) for generator in generators)

            for results in zip(*iterables):
                yield "".join(map(str, results))

        return generate


def write_query(vcz, output=None, *, query_format: str):
    """Write the rendering of every variant in *vcz* to *output*.

    *vcz* is opened read-only with ``zarr.open``. *output* is passed to
    ``open_file_like``; what ``None`` selects is decided by that helper
    (presumably standard output -- confirm). *query_format* is the format
    string handed to ``QueryFormatGenerator``.

    Rendered records are written back to back with no separator, so any
    line breaks must come from the format itself.
    """
    root = zarr.open(vcz, mode="r")
    render = QueryFormatGenerator(query_format)

    with open_file_like(output) as stream:
        for record in render(root):
            print(record, end="", file=stream)