From 96dbe758e7e9d45fc4504e52f06e879d3024e859 Mon Sep 17 00:00:00 2001 From: willtyler Date: Mon, 26 Aug 2024 22:48:20 +0000 Subject: [PATCH] Move filter expression code to filter.py --- tests/test_filter.py | 136 +++++++++++++++++++++++++++++++ tests/test_utils.py | 134 ------------------------------ vcztools/filter.py | 180 +++++++++++++++++++++++++++++++++++++++++ vcztools/utils.py | 176 ---------------------------------------- vcztools/vcf_writer.py | 3 +- 5 files changed, 317 insertions(+), 312 deletions(-) create mode 100644 tests/test_filter.py create mode 100644 vcztools/filter.py diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 0000000..d86b5e8 --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,136 @@ +import pathlib + +import numpy as np +import pyparsing as pp +import pytest +import zarr +from numpy.testing import assert_array_equal + +from tests.utils import vcz_path_cache +from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser + + +class TestFilterExpressionParser: + @pytest.fixture() + def identifier_parser(self, parser): + return parser._identifier_parser + + @pytest.fixture() + def parser(self): + return FilterExpressionParser() + + @pytest.mark.parametrize( + ("expression", "expected_result"), + [ + ("1", [1]), + ("1.0", [1.0]), + ("1e-4", [0.0001]), + ('"String"', ["String"]), + ("POS", ["POS"]), + ("INFO/DP", ["INFO/DP"]), + ("FORMAT/GT", ["FORMAT/GT"]), + ("FMT/GT", ["FMT/GT"]), + ("GT", ["GT"]), + ], + ) + def test_valid_identifiers(self, identifier_parser, expression, expected_result): + assert identifier_parser(expression).as_list() == expected_result + + @pytest.mark.parametrize( + "expression", + [ + "", + "FORMAT/ GT", + "format / GT", + "fmt / GT", + "info / DP", + "'String'", + ], + ) + def test_invalid_identifiers(self, identifier_parser, expression): + with pytest.raises(pp.ParseException): + identifier_parser(expression) + + @pytest.mark.parametrize( + ("expression", "expected_result"), + [ + ("POS>=100", [["POS", ">=", 100]]), + ( + "FMT/DP>10 && FMT/GQ>10", + [[["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]]], + ), + ("QUAL>10 || FMT/GQ>10", [[["QUAL", ">", 10], "||", ["FMT/GQ", ">", 10]]]), + ( + "FMT/DP>10 && FMT/GQ>10 || QUAL > 10", + [ + [ + [["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]], + "||", + ["QUAL", ">", 10], + ] + ], + ), + ( + "QUAL>10 || FMT/DP>10 && FMT/GQ>10", + [ + [ + ["QUAL", ">", 10], + "||", + [["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]], + ] + ], + ), + ( + "QUAL>10 | FMT/DP>10 & FMT/GQ>10", + [ + [ + ["QUAL", ">", 10], + "|", + [["FMT/DP", ">", 10], "&", ["FMT/GQ", ">", 10]], + ], + ], + ), + ( + "(QUAL>10 || FMT/DP>10) && FMT/GQ>10", + [ + [ + [["QUAL", ">", 10], "||", ["FMT/DP", ">", 10]], + "&&", + ["FMT/GQ", ">", 10], + ] + ], + ), + ], + ) + def test_valid_expressions(self, parser, expression, expected_result): + assert parser(expression=expression).as_list() == expected_result + + +class TestFilterExpressionEvaluator: + @pytest.mark.parametrize( + ("expression", "expected_result"), + [ + ("POS < 1000", [1, 1, 0, 0, 0, 0, 0, 0, 1]), + ("FMT/GQ > 20", [0, 0, 1, 1, 1, 1, 1, 0, 0]), + ("FMT/DP >= 5 && FMT/GQ > 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]), + ("FMT/DP >= 5 & FMT/GQ>10", [0, 0, 1, 0, 1, 0, 0, 0, 0]), + ("QUAL > 10 || FMT/GQ>10", [0, 0, 1, 1, 1, 1, 1, 0, 0]), + ("(QUAL > 10 || FMT/GQ>10) && POS > 100000", [0, 0, 0, 0, 1, 1, 1, 0, 0]), + ("(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000", [0, 0, 0, 0, 0, 1, 0, 0, 0]), + ("INFO/DP > 10", [0, 0, 1, 1, 0, 1, 0, 0, 0]), + ("GT > 0", [1, 1, 1, 1, 1, 0, 1, 0, 1]), + ("GT > 0 & FMT/HQ >= 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]), + ], + ) + def test(self, expression, expected_result): + original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz" + vcz = vcz_path_cache(original) + root = zarr.open(vcz, mode="r") + + parser = FilterExpressionParser() + parse_results = parser(expression)[0] + evaluator = FilterExpressionEvaluator(parse_results) + assert_array_equal(evaluator(root, 0), expected_result) + + invert_evaluator = FilterExpressionEvaluator(parse_results, invert=True) + assert_array_equal(invert_evaluator(root, 0), np.logical_not(expected_result)) diff --git a/tests/test_utils.py b/tests/test_utils.py index e9dc6f5..1909635 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,15 +1,7 @@ -import pathlib - -import numpy as np -import pyparsing as pp import pytest -import zarr from numpy.testing import assert_array_equal -from tests.utils import vcz_path_cache from vcztools.utils import ( - FilterExpressionEvaluator, - FilterExpressionParser, search, vcf_name_to_vcz_name, ) @@ -50,129 +42,3 @@ def test_search(a, v, expected_ind): ) def test_vcf_to_vcz(vczs, vcf, expected_vcz): assert vcf_name_to_vcz_name(vczs, vcf) == expected_vcz - - -class TestFilterExpressionParser: - @pytest.fixture() - def identifier_parser(self): - return FilterExpressionParser()._identifier_parser - - @pytest.fixture() - def parser(self): - return FilterExpressionParser() - - @pytest.mark.parametrize( - ("expression", "expected_result"), - [ - ("1", [1]), - ("1.0", [1.0]), - ("1e-4", [0.0001]), - ('"String"', ["String"]), - ("POS", ["POS"]), - ("INFO/DP", ["INFO/DP"]), - ("FORMAT/GT", ["FORMAT/GT"]), - ("FMT/GT", ["FMT/GT"]), - ("GT", ["GT"]), - ], - ) - def test_valid_identifiers(self, identifier_parser, expression, expected_result): - assert identifier_parser(expression).as_list() == expected_result - - @pytest.mark.parametrize( - "expression", - [ - "", - "FORMAT/ GT", - "format / GT", - "fmt / GT", - "info / DP", - "'String'", - ], - ) - def test_invalid_identifiers(self, identifier_parser, expression): - with pytest.raises(pp.ParseException): - identifier_parser(expression) - - @pytest.mark.parametrize( - ("expression", "expected_result"), - [ - ("POS>=100", [["POS", ">=", 100]]), - ( - "FMT/DP>10 && FMT/GQ>10", - [[["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]]], - ), - ("QUAL>10 || FMT/GQ>10", [[["QUAL", ">", 10], "||", ["FMT/GQ", ">", 10]]]), - ( - "FMT/DP>10 && FMT/GQ>10 || QUAL > 10", - [ - [ - [["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]], - "||", - ["QUAL", ">", 10], - ] - ], - ), - ( - "QUAL>10 || FMT/DP>10 && FMT/GQ>10", - [ - [ - ["QUAL", ">", 10], - "||", - [["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]], - ] - ], - ), - ( - "QUAL>10 | FMT/DP>10 & FMT/GQ>10", - [ - [ - ["QUAL", ">", 10], - "|", - [["FMT/DP", ">", 10], "&", ["FMT/GQ", ">", 10]], - ], - ], - ), - ( - "(QUAL>10 || FMT/DP>10) && FMT/GQ>10", - [ - [ - [["QUAL", ">", 10], "||", ["FMT/DP", ">", 10]], - "&&", - ["FMT/GQ", ">", 10], - ] - ], - ), - ], - ) - def test_valid_expressions(self, parser, expression, expected_result): - assert parser(expression=expression).as_list() == expected_result - - -class TestFilterExpressionEvaluator: - @pytest.mark.parametrize( - ("expression", "expected_result"), - [ - ("POS < 1000", [1, 1, 0, 0, 0, 0, 0, 0, 1]), - ("FMT/GQ > 20", [0, 0, 1, 1, 1, 1, 1, 0, 0]), - ("FMT/DP >= 5 && FMT/GQ > 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]), - ("FMT/DP >= 5 & FMT/GQ>10", [0, 0, 1, 0, 1, 0, 0, 0, 0]), - ("QUAL > 10 || FMT/GQ>10", [0, 0, 1, 1, 1, 1, 1, 0, 0]), - ("(QUAL > 10 || FMT/GQ>10) && POS > 100000", [0, 0, 0, 0, 1, 1, 1, 0, 0]), - ("(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000", [0, 0, 0, 0, 0, 1, 0, 0, 0]), - ("INFO/DP > 10", [0, 0, 1, 1, 0, 1, 0, 0, 0]), - ("GT > 0", [1, 1, 1, 1, 1, 0, 1, 0, 1]), - ("GT > 0 & FMT/HQ >= 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]), - ], - ) - def test(self, expression, expected_result): - original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz" - vcz = vcz_path_cache(original) - root = zarr.open(vcz, mode="r") - - parser = FilterExpressionParser() - parse_results = parser(expression)[0] - evaluator = FilterExpressionEvaluator(parse_results) - assert_array_equal(evaluator(root, 0), expected_result) - - invert_evaluator = FilterExpressionEvaluator(parse_results, invert=True) - assert_array_equal(invert_evaluator(root, 0), np.logical_not(expected_result)) diff --git a/vcztools/filter.py b/vcztools/filter.py new file mode 100644 index 0000000..5518920 --- /dev/null +++ b/vcztools/filter.py @@ -0,0 +1,180 @@ +import functools +import operator +from typing import Callable + +import numpy as np +import pyparsing as pp + +from vcztools.utils import vcf_name_to_vcz_name + + +class FilterExpressionParser: + def __init__(self): + constant_pattern = pp.common.number | pp.QuotedString('"') + standard_tag_pattern = pp.Word(pp.srange("[A-Z]")) + + info_tag_pattern = pp.Combine(pp.Literal("INFO/") + standard_tag_pattern) + format_tag_pattern = pp.Combine( + (pp.Literal("FORMAT/") | pp.Literal("FMT/")) + standard_tag_pattern + ) + tag_pattern = info_tag_pattern | format_tag_pattern | standard_tag_pattern + + identifier_pattern = tag_pattern | constant_pattern + self._identifier_parser = functools.partial( + identifier_pattern.parse_string, parse_all=True + ) + + comparison_pattern = pp.Group( + tag_pattern + pp.one_of("== = != > >= < <=") + constant_pattern + ).set_results_name("comparison") + + parentheses_pattern = pp.Forward() + and_pattern = pp.Forward() + or_pattern = pp.Forward() + + parentheses_pattern <<= ( + pp.Suppress("(") + or_pattern + pp.Suppress(")") | comparison_pattern + ) + and_pattern <<= ( + pp.Group( + parentheses_pattern + (pp.Keyword("&&") | pp.Keyword("&")) + and_pattern + ).set_results_name("and") + | parentheses_pattern + ) + or_pattern <<= ( + pp.Group( + and_pattern + (pp.Keyword("||") | pp.Keyword("|")) + or_pattern + ).set_results_name("or") + | and_pattern + ) + + self._parser = functools.partial(or_pattern.parse_string, parse_all=True) + + def __call__(self, *args, **kwargs): + assert args or kwargs + + if args: + assert len(args) == 1 + assert not kwargs + expression = args[0] + else: + assert len(kwargs) == 1 + assert "expression" in kwargs + expression = kwargs["expression"] + + return self.parse(expression) + + def parse(self, expression: str): + return self._parser(expression) + + +class FilterExpressionEvaluator: + def __init__(self, parse_results: pp.ParseResults, *, invert=False): + self._composers = { + "comparison": self._compose_comparison_evaluator, + "and": self._compose_and_evaluator, + "or": self._compose_or_evaluator, + } + self._comparators = { + "==": operator.eq, + "=": operator.eq, + "!=": operator.ne, + ">": operator.gt, + ">=": operator.ge, + "<": operator.lt, + "<=": operator.le, + } + base_evaluator = self._compose_evaluator(parse_results) + + def evaluator(root, variant_chunk_index: int) -> np.ndarray: + base_array = base_evaluator(root, variant_chunk_index) + return np.any(base_array, axis=tuple(range(1, base_array.ndim))) + + if invert: + + def invert_evaluator(root, variant_chunk_index: int) -> np.ndarray: + return np.logical_not(evaluator(root, variant_chunk_index)) + + self._evaluator = invert_evaluator + else: + self._evaluator = evaluator + + def __call__(self, *args, **kwargs): + assert len(args) == 2 + assert not kwargs + + return self._evaluator(*args) + + def _compose_comparison_evaluator(self, parse_results: pp.ParseResults) -> Callable: + assert len(parse_results) == 3 + + comparator = parse_results[1] + comparator = self._comparators[comparator] + + def evaluator(root, variant_chunk_index: int) -> np.ndarray: + vcf_name = parse_results[0] + vcz_names = set(name for name, _array in root.items()) + vcz_name = vcf_name_to_vcz_name(vcz_names, vcf_name) + zarray = root[vcz_name] + variant_chunk_len = zarray.chunks[0] + start = variant_chunk_len * variant_chunk_index + end = start + variant_chunk_len + # We load all samples (regardless of sample filtering) + # to match bcftools' behavior. + array = zarray[start:end] + array = comparator(array, parse_results[2]) + + if array.ndim > 2: + return np.any(array, axis=tuple(range(2, array.ndim))) + else: + return array + + return evaluator + + def _compose_and_evaluator(self, parse_results: pp.ParseResults) -> Callable: + assert len(parse_results) == 3 + assert parse_results[1] in {"&", "&&"} + + left_evaluator = self._compose_evaluator(parse_results[0]) + right_evaluator = self._compose_evaluator(parse_results[2]) + + def evaluator(root, variant_chunk_index): + left_array = left_evaluator(root, variant_chunk_index) + right_array = right_evaluator(root, variant_chunk_index) + + if parse_results[1] == "&": + return np.logical_and(left_array, right_array) + else: + left_array = np.any(left_array, axis=tuple(range(1, left_array.ndim))) + right_array = np.any( + right_array, axis=tuple(range(1, right_array.ndim)) + ) + return np.logical_and(left_array, right_array) + + return evaluator + + def _compose_or_evaluator(self, parse_results: pp.ParseResults) -> Callable: + assert len(parse_results) == 3 + assert parse_results[1] in {"|", "||"} + + left_evaluator = self._compose_evaluator(parse_results[0]) + right_evaluator = self._compose_evaluator(parse_results[2]) + + def evaluator(root, variant_chunk_index: int): + left_array = left_evaluator(root, variant_chunk_index) + right_array = right_evaluator(root, variant_chunk_index) + + if parse_results[1] == "|": + return np.logical_or(left_array, right_array) + else: + left_array = np.any(left_array, axis=tuple(range(1, left_array.ndim))) + right_array = np.any( + right_array, axis=tuple(range(1, right_array.ndim)) + ) + return np.logical_or(left_array, right_array) + + return evaluator + + def _compose_evaluator(self, parse_results: pp.ParseResults) -> Callable: + results_name = parse_results.get_name() + return self._composers[results_name](parse_results) diff --git a/vcztools/utils.py b/vcztools/utils.py index 25389f3..977433c 100644 --- a/vcztools/utils.py +++ b/vcztools/utils.py @@ -1,11 +1,7 @@ -import functools -import operator from contextlib import ExitStack, contextmanager from pathlib import Path -from typing import Callable import numpy as np -import pyparsing as pp from vcztools.constants import RESERVED_VCF_FIELDS @@ -27,66 +23,6 @@ def open_file_like(file): yield file -class FilterExpressionParser: - def __init__(self): - constant_pattern = pp.common.number | pp.QuotedString('"') - standard_tag_pattern = pp.Word(pp.srange("[A-Z]")) - - info_tag_pattern = pp.Combine(pp.Literal("INFO/") + standard_tag_pattern) - format_tag_pattern = pp.Combine( - (pp.Literal("FORMAT/") | pp.Literal("FMT/")) + standard_tag_pattern - ) - tag_pattern = info_tag_pattern | format_tag_pattern | standard_tag_pattern - - identifier_pattern = tag_pattern | constant_pattern - self._identifier_parser = functools.partial( - identifier_pattern.parse_string, parse_all=True - ) - - comparison_pattern = pp.Group( - tag_pattern + pp.one_of("== = != > >= < <=") + constant_pattern - ).set_results_name("comparison") - - parentheses_pattern = pp.Forward() - and_pattern = pp.Forward() - or_pattern = pp.Forward() - - parentheses_pattern <<= ( - pp.Suppress("(") + or_pattern + pp.Suppress(")") | comparison_pattern - ) - and_pattern <<= ( - pp.Group( - parentheses_pattern + (pp.Keyword("&&") | pp.Keyword("&")) + and_pattern - ).set_results_name("and") - | parentheses_pattern - ) - or_pattern <<= ( - pp.Group( - and_pattern + (pp.Keyword("||") | pp.Keyword("|")) + or_pattern - ).set_results_name("or") - | and_pattern - ) - - self._parser = functools.partial(or_pattern.parse_string, parse_all=True) - - def __call__(self, *args, **kwargs): - assert args or kwargs - - if args: - assert len(args) == 1 - assert not kwargs - expression = args[0] - else: - assert len(kwargs) == 1 - assert "expression" in kwargs - expression = kwargs["expression"] - - return self.parse(expression) - - def parse(self, expression: str): - return self._parser(expression) - - def vcf_name_to_vcz_name(vcz_names: set[str], vcf_name: str) -> str: """ Convert the name of a VCF field to the name of the corresponding VCF Zarr array. @@ -115,115 +51,3 @@ def vcf_name_to_vcz_name(vcz_names: set[str], vcf_name: str) -> str: return f"variant_{split[-1]}" else: return RESERVED_VCF_FIELDS[vcf_name] - - -class FilterExpressionEvaluator: - def __init__(self, parse_results: pp.ParseResults, *, invert=False): - self._composers = { - "comparison": self._compose_comparison_evaluator, - "and": self._compose_and_evaluator, - "or": self._compose_or_evaluator, - } - self._comparators = { - "==": operator.eq, - "=": operator.eq, - "!=": operator.ne, - ">": operator.gt, - ">=": operator.ge, - "<": operator.lt, - "<=": operator.le, - } - base_evaluator = self._compose_evaluator(parse_results) - - def evaluator(root, variant_chunk_index: int) -> np.ndarray: - base_array = base_evaluator(root, variant_chunk_index) - return np.any(base_array, axis=tuple(range(1, base_array.ndim))) - - if invert: - - def invert_evaluator(root, variant_chunk_index: int) -> np.ndarray: - return np.logical_not(evaluator(root, variant_chunk_index)) - - self._evaluator = invert_evaluator - else: - self._evaluator = evaluator - - def __call__(self, *args, **kwargs): - assert len(args) == 2 - assert not kwargs - - return self._evaluator(*args) - - def _compose_comparison_evaluator(self, parse_results: pp.ParseResults) -> Callable: - assert len(parse_results) == 3 - - comparator = parse_results[1] - comparator = self._comparators[comparator] - - def evaluator(root, variant_chunk_index: int) -> np.ndarray: - vcf_name = parse_results[0] - vcz_names = set(name for name, _array in root.items()) - vcz_name = vcf_name_to_vcz_name(vcz_names, vcf_name) - zarray = root[vcz_name] - variant_chunk_len = zarray.chunks[0] - start = variant_chunk_len * variant_chunk_index - end = start + variant_chunk_len - # We load all samples (regardless of sample filtering) - # to match bcftools' behavior. - array = zarray[start:end] - array = comparator(array, parse_results[2]) - - if array.ndim > 2: - return np.any(array, axis=tuple(range(2, array.ndim))) - else: - return array - - return evaluator - - def _compose_and_evaluator(self, parse_results: pp.ParseResults) -> Callable: - assert len(parse_results) == 3 - assert parse_results[1] in {"&", "&&"} - - left_evaluator = self._compose_evaluator(parse_results[0]) - right_evaluator = self._compose_evaluator(parse_results[2]) - - def evaluator(root, variant_chunk_index): - left_array = left_evaluator(root, variant_chunk_index) - right_array = right_evaluator(root, variant_chunk_index) - - if parse_results[1] == "&": - return np.logical_and(left_array, right_array) - else: - left_array = np.any(left_array, axis=tuple(range(1, left_array.ndim))) - right_array = np.any( - right_array, axis=tuple(range(1, right_array.ndim)) - ) - return np.logical_and(left_array, right_array) - - return evaluator - - def _compose_or_evaluator(self, parse_results: pp.ParseResults) -> Callable: - assert len(parse_results) == 3 - assert parse_results[1] in {"|", "||"} - - left_evaluator = self._compose_evaluator(parse_results[0]) - right_evaluator = self._compose_evaluator(parse_results[2]) - - def evaluator(root, variant_chunk_index: int): - left_array = left_evaluator(root, variant_chunk_index) - right_array = right_evaluator(root, variant_chunk_index) - - if parse_results[1] == "|": - return np.logical_or(left_array, right_array) - else: - left_array = np.any(left_array, axis=tuple(range(1, left_array.ndim))) - right_array = np.any( - right_array, axis=tuple(range(1, right_array.ndim)) - ) - return np.logical_or(left_array, right_array) - - return evaluator - - def _compose_evaluator(self, parse_results: pp.ParseResults) -> Callable: - results_name = parse_results.get_name() - return self._composers[results_name](parse_results) diff --git a/vcztools/vcf_writer.py b/vcztools/vcf_writer.py index c1b7bc3..b3df584 100644 --- a/vcztools/vcf_writer.py +++ b/vcztools/vcf_writer.py @@ -13,14 +13,13 @@ regions_to_selection, ) from vcztools.utils import ( - FilterExpressionEvaluator, - FilterExpressionParser, open_file_like, search, ) from . import _vcztools from .constants import RESERVED_VARIABLE_NAMES +from .filter import FilterExpressionEvaluator, FilterExpressionParser # references to the VCF spec are for https://samtools.github.io/hts-specs/VCFv4.3.pdf