diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py index 2ed90e8..830669c 100644 --- a/tests/test_bcftools_validation.py +++ b/tests/test_bcftools_validation.py @@ -127,7 +127,13 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file): (r"query -f '%FILTER\n'", "sample.vcf.gz"), (r"query --format '%FILTER\n'", "1kg_2020_chrM.vcf.gz"), (r"query -f '%POS\n' -i 'POS=112'", "sample.vcf.gz"), - (r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz") + (r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz"), + (r"query -f '[%CHROM\t]\n'", "sample.vcf.gz"), + (r"query -f '[%CHROM\t]\n' -i 'POS=112'", "sample.vcf.gz"), + (r"query -f '%CHROM\t%POS\t%REF\t%ALT[\t%GT]\n'", "sample.vcf.gz"), + (r"query -f 'GQ:[ %GQ] \t GT:[ %GT]\n'", "sample.vcf.gz"), + (r"query -f '[%CHROM:%POS %GT\n]'", "sample.vcf.gz"), + (r"query -f '[%GT %DP\n]'", "sample.vcf.gz"), ], ) def test_output(tmp_path, args, vcf_name): diff --git a/tests/test_query.py b/tests/test_query.py index 97f5bce..91b62a8 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -53,6 +53,42 @@ def parser(self): r"Read depth: %INFO/DP\n", ["Read", " ", "depth:", " ", "%INFO/DP", "\n"], ), + ( + r"%CHROM\t%POS\t%REF\t%ALT[\t%SAMPLE=%GT]\n", + [ + "%CHROM", + "\t", + "%POS", + "\t", + "%REF", + "\t", + "%ALT", + ["\t", "%SAMPLE", "=", "%GT"], + "\n", + ], + ), + ( + r"%CHROM\t%POS\t%REF\t%ALT[\t%SAMPLE=%GT{0}]\n", + [ + "%CHROM", + "\t", + "%POS", + "\t", + "%REF", + "\t", + "%ALT", + ["\t", "%SAMPLE", "=", ["%GT", 0]], + "\n", + ], + ), + ( + r"GQ:[ %GQ] \t GT:[ %GT]\n", + ["GQ:", [" ", "%GQ"], " ", "\t", " ", "GT:", [" ", "%GT"], "\n"], + ), + ( + r"[%SAMPLE %GT %DP\n]", + [["%SAMPLE", " ", "%GT", " ", "%DP", "\n"]], + ), ], ) def test_valid_expressions(self, parser, expression, expected_result): diff --git a/vcztools/query.py b/vcztools/query.py index a1e40eb..4df2b43 100644 --- a/vcztools/query.py +++ b/vcztools/query.py @@ -7,6 +7,7 @@ import pyparsing as pp import zarr +from vcztools import constants from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser from vcztools.utils import open_file_like, vcf_name_to_vcz_name @@ -27,25 +28,29 @@ def __init__(self): tag_pattern = info_tag_pattern | pp.Combine( pp.Literal("%") + pp.Regex(r"[A-Z]+\d?") ) - subscript_pattern = pp.Group( + subfield_pattern = pp.Group( tag_pattern + pp.Literal("{").suppress() + pp.common.integer + pp.Literal("}").suppress() - ) + ).set_results_name("subfield") newline_pattern = pp.Literal("\\n").set_parse_action(pp.replace_with("\n")) tab_pattern = pp.Literal("\\t").set_parse_action(pp.replace_with("\t")) - pattern = pp.ZeroOrMore( - subscript_pattern + format_pattern = pp.Forward() + sample_loop_pattern = pp.Group( + pp.Literal("[").suppress() + format_pattern + pp.Literal("]").suppress() + ).set_results_name("sample loop") + format_pattern <<= pp.ZeroOrMore( + sample_loop_pattern + | subfield_pattern | tag_pattern | newline_pattern | tab_pattern | pp.White() | pp.Word(pp.printables, exclude_chars=r"\{}[]%") - ) - pattern = pattern.leave_whitespace() + ).leave_whitespace() - self._parser = functools.partial(pattern.parse_string, parse_all=True) + self._parser = functools.partial(format_pattern.parse_string, parse_all=True) def __call__(self, *args, **kwargs): assert len(args) == 1 @@ -79,10 +84,51 @@ def __call__(self, *args, **kwargs): yield from self._generator(args[0]) - def _compose_tag_generator(self, tag: str, *, subscript=False) -> Callable: + def _compose_gt_generator(self) -> Callable: + def generate(root): + gt_zarray = root["call_genotype"] + v_chunk_size = gt_zarray.chunks[0] + + if "call_genotype_phased" in root: + phase_zarray = root["call_genotype_phased"] + assert gt_zarray.chunks[:2] == phase_zarray.chunks + assert gt_zarray.shape[:2] == phase_zarray.shape + + for v_chunk_index in range(gt_zarray.cdata_shape[0]): + start = v_chunk_index * v_chunk_size + end = start + v_chunk_size + + for gt_row, phase in zip( + gt_zarray[start:end], phase_zarray[start:end] + ): + + def stringify(gt_and_phase: tuple): + gt, phase = gt_and_phase + gt = [ + str(allele) if allele != constants.INT_MISSING else "." + for allele in gt + if allele != constants.INT_FILL + ] + separator = "|" if phase else "/" + return separator.join(gt) + + gt_row = gt_row.tolist() + yield map(stringify, zip(gt_row, phase)) + else: + # TODO: Support datasets without the phasing data + raise NotImplementedError + + return generate + + def _compose_tag_generator( + self, tag: str, *, subfield=False, sample_loop=False + ) -> Callable: assert tag.startswith("%") tag = tag[1:] + if tag == "GT": + return self._compose_gt_generator() + def generate(root): vcz_names = set(name for name, _zarray in root.items()) vcz_name = vcf_name_to_vcz_name(vcz_names, tag) @@ -104,8 +150,7 @@ def generate(root): if tag == "REF": row = row[0] if tag == "ALT": - row = [allele for allele in row[1:] if allele] - row = row or "." + row = [allele for allele in row[1:] if allele] or "." if tag == "FILTER": assert filter_ids is not None @@ -118,20 +163,38 @@ def generate(root): row = "." else: row = f"{row:g}" - if not subscript and ( - isinstance(row, np.ndarray) or isinstance(row, list) + if ( + not subfield + and not sample_loop + and (isinstance(row, np.ndarray) or isinstance(row, list)) ): row = ",".join(map(str, row)) - yield row if not is_missing else "." + if sample_loop: + sample_count = root["sample_id"].shape[0] + + if isinstance(row, np.ndarray): + row = row.tolist() + row = [ + str(element) + if element != constants.INT_MISSING + else "." + for element in row + if element != constants.INT_FILL + ] + yield row + else: + yield itertools.repeat(str(row), sample_count) + else: + yield row if not is_missing else "." return generate - def _compose_subscript_generator(self, parse_results: pp.ParseResults) -> Callable: + def _compose_subfield_generator(self, parse_results: pp.ParseResults) -> Callable: assert len(parse_results) == 2 - tag, subscript_index = parse_results - tag_generator = self._compose_tag_generator(tag, subscript=True) + tag, subfield_index = parse_results + tag_generator = self._compose_tag_generator(tag, subfield=True) def generate(root): for tag in tag_generator(root): @@ -139,28 +202,61 @@ def generate(root): assert tag == "." yield "." else: - if subscript_index < len(tag): - yield tag[subscript_index] + if subfield_index < len(tag): + yield tag[subfield_index] else: yield "." return generate + def _compose_sample_loop_generator( + self, parse_results: pp.ParseResults + ) -> Callable: + generators = map( + functools.partial(self._compose_element_generator, sample_loop=True), + parse_results, + ) + + def generate(root): + iterables = (generator(root) for generator in generators) + zipped = zip(*iterables) + zipped_zipped = (zip(*element) for element in zipped) + flattened_zipped_zipped = ( + ( + subsubelement + for subelement in element # sample-wise + for subsubelement in subelement + ) + for element in zipped_zipped # variant-wise + ) + yield from map("".join, flattened_zipped_zipped) + + return generate + def _compose_element_generator( - self, element: Union[str, pp.ParseResults] + self, element: Union[str, pp.ParseResults], *, sample_loop=False ) -> Callable: if isinstance(element, pp.ParseResults): - return self._compose_subscript_generator(element) + if element.get_name() == "subfield": + return self._compose_subfield_generator(element) + elif element.get_name() == "sample loop": + return self._compose_sample_loop_generator(element) assert isinstance(element, str) if element.startswith("%"): - return self._compose_tag_generator(element) + return self._compose_tag_generator(element, sample_loop=sample_loop) else: def generate(root): + nonlocal element variant_count = root["variant_position"].shape[0] - yield from itertools.repeat(element, variant_count) + if sample_loop: + sample_count = root["sample_id"].shape[0] + for _ in range(variant_count): + yield itertools.repeat(element, sample_count) + else: + yield from itertools.repeat(element, variant_count) return generate