diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py index 2ed90e8..a3d742f 100644 --- a/tests/test_bcftools_validation.py +++ b/tests/test_bcftools_validation.py @@ -127,7 +127,9 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file): (r"query -f '%FILTER\n'", "sample.vcf.gz"), (r"query --format '%FILTER\n'", "1kg_2020_chrM.vcf.gz"), (r"query -f '%POS\n' -i 'POS=112'", "sample.vcf.gz"), - (r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz") + (r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz"), + (r"query -f '[%CHROM\t]\n'", "sample.vcf.gz"), + (r"query -f '[%CHROM\t]\n' -i 'POS=112'", "sample.vcf.gz"), ], ) def test_output(tmp_path, args, vcf_name): diff --git a/vcztools/query.py b/vcztools/query.py index 6ad0a09..a93dbb1 100644 --- a/vcztools/query.py +++ b/vcztools/query.py @@ -27,21 +27,21 @@ def __init__(self): tag_pattern = info_tag_pattern | pp.Combine( pp.Literal("%") + pp.Regex(r"[A-Z]+\d?") ) - subscript_pattern = pp.Group( + subfield_pattern = pp.Group( tag_pattern + pp.Literal("{").suppress() + pp.common.integer + pp.Literal("}").suppress() - ) + ).set_results_name("subfield") newline_pattern = pp.Literal("\\n").set_parse_action(pp.replace_with("\n")) tab_pattern = pp.Literal("\\t").set_parse_action(pp.replace_with("\t")) format_pattern = pp.Forward() sample_loop_pattern = pp.Group( pp.Literal("[").suppress() + format_pattern + pp.Literal("]").suppress() - ) + ).set_results_name("sample loop") format_pattern <<= pp.ZeroOrMore( sample_loop_pattern - | subscript_pattern + | subfield_pattern | tag_pattern | newline_pattern | tab_pattern @@ -83,7 +83,9 @@ def __call__(self, *args, **kwargs): yield from self._generator(args[0]) - def _compose_tag_generator(self, tag: str, *, subscript=False) -> Callable: + def _compose_tag_generator( + self, tag: str, *, subfield=False, sample_loop=False + ) -> Callable: assert tag.startswith("%") tag = tag[1:] @@ -122,20 +124,26 @@ def generate(root): row = "." else: row = f"{row:g}" - if not subscript and ( + if not subfield and ( isinstance(row, np.ndarray) or isinstance(row, list) ): row = ",".join(map(str, row)) - yield row if not is_missing else "." + result = row if not is_missing else "." + + if sample_loop: + sample_count = root["sample_id"].shape[0] + yield itertools.repeat(row, sample_count) + else: + yield result return generate - def _compose_subscript_generator(self, parse_results: pp.ParseResults) -> Callable: + def _compose_subfield_generator(self, parse_results: pp.ParseResults) -> Callable: assert len(parse_results) == 2 - tag, subscript_index = parse_results - tag_generator = self._compose_tag_generator(tag, subscript=True) + tag, subfield_index = parse_results + tag_generator = self._compose_tag_generator(tag, subfield=True) def generate(root): for tag in tag_generator(root): @@ -143,28 +151,61 @@ def generate(root): assert tag == "." yield "." else: - if subscript_index < len(tag): - yield tag[subscript_index] + if subfield_index < len(tag): + yield tag[subfield_index] else: yield "." return generate + def _compose_sample_loop_generator( + self, parse_results: pp.ParseResults + ) -> Callable: + generators = map( + functools.partial(self._compose_element_generator, sample_loop=True), + parse_results, + ) + + def generate(root): + iterables = (generator(root) for generator in generators) + zipped = zip(*iterables) + zipped_zipped = (zip(*element) for element in zipped) + flattened_zipped_zipped = ( + ( + subsubelement + for subelement in element # sample-wise + for subsubelement in subelement + ) + for element in zipped_zipped # variant-wise + ) + yield from map("".join, flattened_zipped_zipped) + + return generate + def _compose_element_generator( - self, element: Union[str, pp.ParseResults] + self, element: Union[str, pp.ParseResults], *, sample_loop=False ) -> Callable: if isinstance(element, pp.ParseResults): - return self._compose_subscript_generator(element) + if element.get_name() == "subfield": + return self._compose_subfield_generator(element) + elif element.get_name() == "sample loop": + return self._compose_sample_loop_generator(element) assert isinstance(element, str) if element.startswith("%"): - return self._compose_tag_generator(element) + return self._compose_tag_generator(element, sample_loop=sample_loop) else: def generate(root): + nonlocal element variant_count = root["variant_position"].shape[0] - yield from itertools.repeat(element, variant_count) + if sample_loop: + sample_count = root["sample_id"].shape[0] + for _ in range(variant_count): + yield itertools.repeat(element, sample_count) + else: + yield from itertools.repeat(element, variant_count) return generate