diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py index 3a5cb88..cbb80e2 100644 --- a/tests/test_bcftools_validation.py +++ b/tests/test_bcftools_validation.py @@ -76,6 +76,16 @@ def test_vcf_output(tmp_path, args, vcf_file): ("index --nrecords", "1kg_2020_chrM.vcf.gz"), ("query -l", "sample.vcf.gz"), ("query --list-samples", "1kg_2020_chrM.vcf.gz"), + (r"query -f 'A\n'", "sample.vcf.gz"), + (r"query -f '%CHROM:%POS\n'", "sample.vcf.gz"), + (r"query -f '%INFO/DP\n'", "sample.vcf.gz"), + (r"query -f '%AC{0}\n'", "sample.vcf.gz"), + (r"query -f '%REF\t%ALT\n'", "sample.vcf.gz"), + (r"query -f '%ALT{1}\n'", "sample.vcf.gz"), + (r"query -f '%ID\n'", "sample.vcf.gz"), + (r"query -f '%QUAL\n'", "sample.vcf.gz"), + (r"query -f '%FILTER\n'", "sample.vcf.gz"), + (r"query --format '%FILTER\n'", "1kg_2020_chrM.vcf.gz"), ], ) def test_output(tmp_path, args, vcf_name): diff --git a/tests/test_query.py b/tests/test_query.py index ac0fbbb..26555fa 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,8 +1,12 @@ import pathlib from io import StringIO +import pyparsing as pp +import pytest +import zarr + from tests.utils import vcz_path_cache -from vcztools.query import list_samples +from vcztools.query import QueryFormatGenerator, QueryFormatParser, list_samples def test_list_samples(tmp_path): @@ -13,3 +17,87 @@ def test_list_samples(tmp_path): with StringIO() as output: list_samples(vcz_path, output) assert output.getvalue() == expected_output + + +class TestQueryFormatParser: + @pytest.fixture() + def parser(self): + return QueryFormatParser() + + @pytest.mark.parametrize( + ("expression", "expected_result"), + [ + ("%CHROM", ["%CHROM"]), + (r"\n", ["\n"]), + (r"\t", ["\t"]), + (r"%CHROM\n", ["%CHROM", "\n"]), + ("%CHROM %POS %REF", ["%CHROM", " ", "%POS", " ", "%REF"]), + (r"%CHROM %POS0 %REF\n", ["%CHROM", " ", "%POS0", " ", "%REF", "\n"]), + ( + r"%CHROM\t%POS\t%REF\t%ALT{0}\n", + ["%CHROM", "\t", "%POS", "\t", "%REF", "\t", ["%ALT", 0], "\n"], + ), + ( + r"%CHROM\t%POS0\t%END\t%ID\n", + ["%CHROM", "\t", "%POS0", "\t", "%END", "\t", "%ID", "\n"], + ), + (r"%CHROM:%POS\n", ["%CHROM", ":", "%POS", "\n"]), + (r"%AC{1}\n", [["%AC", 1], "\n"]), + ( + r"Read depth: %INFO/DP\n", + ["Read", " ", "depth:", " ", "%INFO/DP", "\n"], + ), + ], + ) + def test_valid_expressions(self, parser, expression, expected_result): + assert parser(expression).as_list() == expected_result + + @pytest.mark.parametrize( + "expression", + [ + "%ac", + "%AC {1}", + "% CHROM", + ], + ) + def test_invalid_expressions(self, parser, expression): + with pytest.raises(pp.ParseException): + parser(expression) + + +class TestQueryFormatEvaluator: + @pytest.fixture() + def root(self): + vcf_path = pathlib.Path("tests/data/vcf/sample.vcf.gz") + vcz_path = vcz_path_cache(vcf_path) + return zarr.open(vcz_path, mode="r") + + @pytest.mark.parametrize( + ("query_format", "expected_result"), + [ + (r"A\t", "A\t" * 9), + (r"CHROM", "CHROM" * 9), + ( + r"%CHROM:%POS\n", + "19:111\n19:112\n20:14370\n20:17330\n20:1110696\n20:1230237\n20:1234567\n20:1235237\nX:10\n", + ), + (r"%INFO/DP\n", ".\n.\n14\n11\n10\n13\n9\n.\n.\n"), + (r"%AC\n", ".\n.\n.\n.\n.\n.\n3,1\n.\n.\n"), + (r"%AC{0}\n", ".\n.\n.\n.\n.\n.\n3\n.\n.\n"), + ], + ) + def test(self, root, query_format, expected_result): + generator = QueryFormatGenerator(query_format) + result = "".join(generator(root)) + assert result == expected_result + + @pytest.mark.parametrize( + ("query_format", "expected_result"), + [(r"%QUAL\n", "9.6\n10\n29\n3\n67\n47\n50\n.\n10\n")], + ) + def test_with_parse_results(self, root, query_format, expected_result): + parser = QueryFormatParser() + parse_results = parser(query_format) + generator = QueryFormatGenerator(parse_results) + result = "".join(generator(root)) + assert result == expected_result diff --git a/vcztools/cli.py b/vcztools/cli.py index 873880a..657aa1c 100644 --- a/vcztools/cli.py +++ b/vcztools/cli.py @@ -38,11 +38,14 @@ def index(path, nrecords): is_flag=True, help="List the sample IDs and exit.", ) -def query(path, list_samples): +@click.option("-f", "--format", type=str, help="The format of the output.") +def query(path, list_samples, format): if list_samples: query_module.list_samples(path) return + query_module.write_query(path, query_format=format) + @click.command @click.argument("path", type=click.Path()) diff --git a/vcztools/query.py b/vcztools/query.py index e17cc0e..3a59d54 100644 --- a/vcztools/query.py +++ b/vcztools/query.py @@ -1,6 +1,13 @@ +import functools +import itertools +import math +from typing import Callable, Union + +import numpy as np +import pyparsing as pp import zarr -from vcztools.utils import open_file_like +from vcztools.utils import open_file_like, vcf_name_to_vcz_name def list_samples(vcz_path, output=None): @@ -9,3 +16,164 @@ def list_samples(vcz_path, output=None): with open_file_like(output) as output: sample_ids = root["sample_id"][:] print("\n".join(sample_ids), file=output) + + +class QueryFormatParser: + def __init__(self): + info_tag_pattern = pp.Combine( + pp.Literal("%INFO/") + pp.Word(pp.srange("[A-Z]")) + ) + tag_pattern = info_tag_pattern | pp.Combine( + pp.Literal("%") + pp.Regex(r"[A-Z]+\d?") + ) + subscript_pattern = pp.Group( + tag_pattern + + pp.Literal("{").suppress() + + pp.common.integer + + pp.Literal("}").suppress() + ) + newline_pattern = pp.Literal("\\n").set_parse_action(pp.replace_with("\n")) + tab_pattern = pp.Literal("\\t").set_parse_action(pp.replace_with("\t")) + pattern = pp.ZeroOrMore( + subscript_pattern + | tag_pattern + | newline_pattern + | tab_pattern + | pp.White() + | pp.Word(pp.printables, exclude_chars=r"\{}[]%") + ) + pattern = pattern.leave_whitespace() + + self._parser = functools.partial(pattern.parse_string, parse_all=True) + + def __call__(self, *args, **kwargs): + assert len(args) == 1 + assert not kwargs + + return self._parser(args[0]) + + +class QueryFormatGenerator: + def __init__(self, query_format: Union[str, pp.ParseResults]): + if isinstance(query_format, str): + parser = QueryFormatParser() + parse_results = parser(query_format) + else: + assert isinstance(query_format, pp.ParseResults) + parse_results = query_format + + self._generator = self._compose_generator(parse_results) + + def __call__(self, *args, **kwargs): + assert len(args) == 1 + assert not kwargs + + yield from self._generator(args[0]) + + def _compose_tag_generator(self, tag: str, *, subscript=False) -> Callable: + assert tag.startswith("%") + tag = tag[1:] + + def generate(root): + vcz_names = set(name for name, _zarray in root.items()) + vcz_name = vcf_name_to_vcz_name(vcz_names, tag) + zarray = root[vcz_name] + contig_ids = root["contig_id"][:] if tag == "CHROM" else None + filter_ids = root["filter_id"][:] if tag == "FILTER" else None + v_chunk_size = zarray.chunks[0] + + for v_chunk_index in range(zarray.cdata_shape[0]): + start = v_chunk_index * v_chunk_size + end = start + v_chunk_size + + for row in zarray[start:end]: + is_missing = np.any(row == -1) + + if tag == "CHROM": + assert contig_ids is not None + row = contig_ids[row] + if tag == "REF": + row = row[0] + if tag == "ALT": + row = [allele for allele in row[1:] if allele] + row = row or "." + if tag == "FILTER": + assert filter_ids is not None + + if np.any(row): + row = filter_ids[row] + else: + row = "." + if tag == "QUAL": + if math.isnan(row): + row = "." + else: + row = f"{row:g}" + if not subscript and ( + isinstance(row, np.ndarray) or isinstance(row, list) + ): + row = ",".join(map(str, row)) + + yield row if not is_missing else "." + + return generate + + def _compose_subscript_generator(self, parse_results: pp.ParseResults) -> Callable: + assert len(parse_results) == 2 + + tag, subscript_index = parse_results + tag_generator = self._compose_tag_generator(tag, subscript=True) + + def generate(root): + for tag in tag_generator(root): + if isinstance(tag, str): + assert tag == "." + yield "." + else: + if subscript_index < len(tag): + yield tag[subscript_index] + else: + yield "." + + return generate + + def _compose_element_generator( + self, element: Union[str, pp.ParseResults] + ) -> Callable: + if isinstance(element, pp.ParseResults): + return self._compose_subscript_generator(element) + + assert isinstance(element, str) + + if element.startswith("%"): + return self._compose_tag_generator(element) + else: + + def generate(root): + variant_count = root["variant_position"].shape[0] + yield from itertools.repeat(element, variant_count) + + return generate + + def _compose_generator(self, parse_results: pp.ParseResults) -> Callable: + generators = ( + self._compose_element_generator(element) for element in parse_results + ) + + def generate(root) -> str: + iterables = (generator(root) for generator in generators) + + for results in zip(*iterables): + results = map(str, results) + yield "".join(results) + + return generate + + +def write_query(vcz, output=None, *, query_format: str): + root = zarr.open(vcz, mode="r") + generator = QueryFormatGenerator(query_format) + + with open_file_like(output) as output: + for result in generator(root): + print(result, sep="", end="", file=output)