Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query variant filtering #82

Merged
merged 4 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

.vscode
vcz_test_cache/
2 changes: 2 additions & 0 deletions tests/test_bcftools_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
(r"query -f '%QUAL\n'", "sample.vcf.gz"),
(r"query -f '%FILTER\n'", "sample.vcf.gz"),
(r"query --format '%FILTER\n'", "1kg_2020_chrM.vcf.gz"),
(r"query -f '%POS\n' -i 'POS=112'", "sample.vcf.gz"),
(r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz")
],
)
def test_output(tmp_path, args, vcf_name):
Expand Down
31 changes: 30 additions & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import pathlib
import re
from io import StringIO

import pyparsing as pp
import pytest
import zarr

from tests.utils import vcz_path_cache
from vcztools.query import QueryFormatGenerator, QueryFormatParser, list_samples
from vcztools.query import (
QueryFormatGenerator,
QueryFormatParser,
list_samples,
write_query,
)


def test_list_samples(tmp_path):
Expand Down Expand Up @@ -101,3 +107,26 @@ def test_with_parse_results(self, root, query_format, expected_result):
generator = QueryFormatGenerator(parse_results)
result = "".join(generator(root))
assert result == expected_result


def test_write_query__include_exclude(tmp_path):
    """write_query must refuse an include and an exclude expression together."""
    vcf_path = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
    vcz = vcz_path_cache(vcf_path)
    output = tmp_path / "output.vcf"

    # The same expression is passed for both options; only their
    # simultaneous presence matters for this check.
    expression = "POS > 1"
    expected_message = re.escape(
        "Cannot handle both an include expression and an exclude expression."
    )

    with pytest.raises(ValueError, match=expected_message):
        write_query(
            vcz,
            output,
            query_format=r"%POS\n",
            include=expression,
            exclude=expression,
        )
23 changes: 15 additions & 8 deletions vcztools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from . import regions, vcf_writer
from . import stats as stats_module

# Shared filter-expression options. They are built once at module level so
# that both the `query` and `view` commands can apply them as decorators
# (`@include` / `@exclude`) instead of repeating the option definitions.
include = click.option(
"-i", "--include", type=str, help="Filter expression to include variant sites."
)
exclude = click.option(
"-e", "--exclude", type=str, help="Filter expression to exclude variant sites."
)


class NaturalOrderGroup(click.Group):
"""
Expand Down Expand Up @@ -48,12 +55,16 @@ def index(path, nrecords, stats):
help="List the sample IDs and exit.",
)
@click.option("-f", "--format", type=str, help="The format of the output.")
def query(path, list_samples, format):
@include
@exclude
def query(path, list_samples, format, include, exclude):
if list_samples:
query_module.list_samples(path)
return

query_module.write_query(path, query_format=format)
query_module.write_query(
path, query_format=format, include=include, exclude=exclude
)


@click.command
Expand Down Expand Up @@ -122,12 +133,8 @@ def query(path, list_samples, format):
default=None,
help="Target regions to include.",
)
@click.option(
"-i", "--include", type=str, help="Filter expression to include variant sites."
)
@click.option(
"-e", "--exclude", type=str, help="Filter expression to exclude variant sites."
)
@include
@exclude
def view(
path,
output,
Expand Down
80 changes: 71 additions & 9 deletions vcztools/query.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import functools
import itertools
import math
from typing import Callable, Union
from typing import Callable, Optional, Union

import numpy as np
import pyparsing as pp
import zarr

from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser
from vcztools.utils import open_file_like, vcf_name_to_vcz_name


Expand Down Expand Up @@ -54,15 +55,23 @@ def __call__(self, *args, **kwargs):


class QueryFormatGenerator:
def __init__(self, query_format: Union[str, pp.ParseResults]):
def __init__(
self,
query_format: Union[str, pp.ParseResults],
*,
include: Optional[str] = None,
exclude: Optional[str] = None,
):
if isinstance(query_format, str):
parser = QueryFormatParser()
parse_results = parser(query_format)
else:
assert isinstance(query_format, pp.ParseResults)
parse_results = query_format

self._generator = self._compose_generator(parse_results)
self._generator = self._compose_generator(
parse_results, include=include, exclude=exclude
)

def __call__(self, *args, **kwargs):
assert len(args) == 1
Expand Down Expand Up @@ -155,24 +164,77 @@ def generate(root):

return generate

def _compose_generator(self, parse_results: pp.ParseResults) -> Callable:
def _compose_filter_generator(
    self, *, include: Optional[str] = None, exclude: Optional[str] = None
) -> Callable:
    """Build a callable producing the variant filter stream for a root.

    The returned callable takes the opened store (``root``) and returns an
    iterator of keep/drop indicators that is consumed in lockstep with the
    formatted output records.

    :param include: Filter expression selecting the variants to keep.
    :param exclude: Filter expression selecting the variants to drop; the
        evaluator is constructed with ``invert=True`` in this case.
    :returns: A function ``generate(root)``. With no expression it yields
        ``True`` once per entry of ``root["variant_position"]``; otherwise
        it yields whatever the filter evaluator produces for each variant
        chunk, in chunk order (presumably one boolean per variant --
        confirm against ``FilterExpressionEvaluator``).
    :raises ValueError: If both ``include`` and ``exclude`` are given.
    """
    # Raise instead of assert: asserts are stripped under ``python -O`` and
    # this is validation of caller-supplied arguments.
    if include and exclude:
        raise ValueError(
            "Cannot handle both an include expression and an exclude expression."
        )

    expression = include or exclude
    if not expression:

        def generate_all(root):
            variant_count = root["variant_position"].shape[0]
            yield from itertools.repeat(True, variant_count)

        return generate_all

    parser = FilterExpressionParser()
    parse_results = parser(expression)[0]
    filter_evaluator = FilterExpressionEvaluator(
        parse_results, invert=bool(exclude)
    )

    def generate(root):
        # Bind ``root`` to a *local* name. Rebinding the enclosing
        # ``filter_evaluator`` (via ``nonlocal``) would stack another
        # partial on every call, so a second call would pass ``root`` to
        # the evaluator more than once.
        evaluate_chunk = functools.partial(filter_evaluator, root)
        variant_chunk_count = root["variant_position"].cdata_shape[0]

        for variant_chunk_index in range(variant_chunk_count):
            yield from evaluate_chunk(variant_chunk_index)

    return generate

def _compose_generator(
    self,
    parse_results: "pp.ParseResults",
    *,
    include: Optional[str] = None,
    exclude: Optional[str] = None,
) -> Callable:
    """Build the callable that renders the query output for a root.

    One element generator is composed per parsed format element, plus one
    filter generator for the include/exclude expression. The returned
    callable walks them in lockstep and emits one string per record that
    the filter keeps.

    :param parse_results: Parsed query format; iterated to obtain the
        elements handed to ``_compose_element_generator``.
    :param include: Filter expression forwarded to
        ``_compose_filter_generator``.
    :param exclude: Filter expression forwarded to
        ``_compose_filter_generator``.
    :returns: A function ``generate(root)`` yielding the concatenated,
        stringified element values of every record whose filter indicator
        is truthy.
    """
    # A list, not a generator expression: the element generators are reused
    # on every ``generate`` call, and a one-shot generator expression would
    # be exhausted after the first call, making later calls yield nothing.
    generators = [
        self._compose_element_generator(element) for element in parse_results
    ]
    filter_generator = self._compose_filter_generator(
        include=include, exclude=exclude
    )

    def generate(root):
        iterables = (generator(root) for generator in generators)
        filter_iterable = filter_generator(root)

        # zip stops at the shorter stream, so records and filter indicators
        # are paired positionally.
        for results, filter_indicator in zip(zip(*iterables), filter_iterable):
            if filter_indicator:
                yield "".join(map(str, results))

    return generate


def write_query(vcz, output=None, *, query_format: str):
def write_query(
vcz,
output=None,
*,
query_format: str,
include: Optional[str] = None,
exclude: Optional[str] = None,
):
if include and exclude:
raise ValueError(
"Cannot handle both an include expression and an exclude expression."
)

root = zarr.open(vcz, mode="r")
generator = QueryFormatGenerator(query_format)
generator = QueryFormatGenerator(query_format, include=include, exclude=exclude)

with open_file_like(output) as output:
for result in generator(root):
Expand Down