Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query samples looping #88

Merged
merged 4 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tests/test_bcftools_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,13 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
(r"query -f '%FILTER\n'", "sample.vcf.gz"),
(r"query --format '%FILTER\n'", "1kg_2020_chrM.vcf.gz"),
(r"query -f '%POS\n' -i 'POS=112'", "sample.vcf.gz"),
(r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz")
(r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz"),
(r"query -f '[%CHROM\t]\n'", "sample.vcf.gz"),
(r"query -f '[%CHROM\t]\n' -i 'POS=112'", "sample.vcf.gz"),
(r"query -f '%CHROM\t%POS\t%REF\t%ALT[\t%GT]\n'", "sample.vcf.gz"),
(r"query -f 'GQ:[ %GQ] \t GT:[ %GT]\n'", "sample.vcf.gz"),
(r"query -f '[%CHROM:%POS %GT\n]'", "sample.vcf.gz"),
(r"query -f '[%GT %DP\n]'", "sample.vcf.gz"),
],
)
def test_output(tmp_path, args, vcf_name):
Expand Down
36 changes: 36 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,42 @@ def parser(self):
r"Read depth: %INFO/DP\n",
["Read", " ", "depth:", " ", "%INFO/DP", "\n"],
),
(
r"%CHROM\t%POS\t%REF\t%ALT[\t%SAMPLE=%GT]\n",
[
"%CHROM",
"\t",
"%POS",
"\t",
"%REF",
"\t",
"%ALT",
["\t", "%SAMPLE", "=", "%GT"],
"\n",
],
),
(
r"%CHROM\t%POS\t%REF\t%ALT[\t%SAMPLE=%GT{0}]\n",
[
"%CHROM",
"\t",
"%POS",
"\t",
"%REF",
"\t",
"%ALT",
["\t", "%SAMPLE", "=", ["%GT", 0]],
"\n",
],
),
(
r"GQ:[ %GQ] \t GT:[ %GT]\n",
["GQ:", [" ", "%GQ"], " ", "\t", " ", "GT:", [" ", "%GT"], "\n"],
),
(
r"[%SAMPLE %GT %DP\n]",
[["%SAMPLE", " ", "%GT", " ", "%DP", "\n"]],
),
],
)
def test_valid_expressions(self, parser, expression, expected_result):
Expand Down
140 changes: 118 additions & 22 deletions vcztools/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pyparsing as pp
import zarr

from vcztools import constants
from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser
from vcztools.utils import open_file_like, vcf_name_to_vcz_name

Expand All @@ -27,25 +28,29 @@ def __init__(self):
tag_pattern = info_tag_pattern | pp.Combine(
pp.Literal("%") + pp.Regex(r"[A-Z]+\d?")
)
subscript_pattern = pp.Group(
subfield_pattern = pp.Group(
tag_pattern
+ pp.Literal("{").suppress()
+ pp.common.integer
+ pp.Literal("}").suppress()
)
).set_results_name("subfield")
newline_pattern = pp.Literal("\\n").set_parse_action(pp.replace_with("\n"))
tab_pattern = pp.Literal("\\t").set_parse_action(pp.replace_with("\t"))
pattern = pp.ZeroOrMore(
subscript_pattern
format_pattern = pp.Forward()
sample_loop_pattern = pp.Group(
pp.Literal("[").suppress() + format_pattern + pp.Literal("]").suppress()
).set_results_name("sample loop")
format_pattern <<= pp.ZeroOrMore(
sample_loop_pattern
| subfield_pattern
| tag_pattern
| newline_pattern
| tab_pattern
| pp.White()
| pp.Word(pp.printables, exclude_chars=r"\{}[]%")
)
pattern = pattern.leave_whitespace()
).leave_whitespace()

self._parser = functools.partial(pattern.parse_string, parse_all=True)
self._parser = functools.partial(format_pattern.parse_string, parse_all=True)

def __call__(self, *args, **kwargs):
assert len(args) == 1
Expand Down Expand Up @@ -79,10 +84,51 @@ def __call__(self, *args, **kwargs):

yield from self._generator(args[0])

def _compose_tag_generator(self, tag: str, *, subscript=False) -> Callable:
def _compose_gt_generator(self) -> Callable:
def generate(root):
gt_zarray = root["call_genotype"]
v_chunk_size = gt_zarray.chunks[0]

if "call_genotype_phased" in root:
phase_zarray = root["call_genotype_phased"]
assert gt_zarray.chunks[:2] == phase_zarray.chunks
assert gt_zarray.shape[:2] == phase_zarray.shape

for v_chunk_index in range(gt_zarray.cdata_shape[0]):
start = v_chunk_index * v_chunk_size
end = start + v_chunk_size

for gt_row, phase in zip(
gt_zarray[start:end], phase_zarray[start:end]
):

def stringify(gt_and_phase: tuple):
gt, phase = gt_and_phase
gt = [
str(allele) if allele != constants.INT_MISSING else "."
for allele in gt
if allele != constants.INT_FILL
]
separator = "|" if phase else "/"
return separator.join(gt)

gt_row = gt_row.tolist()
yield map(stringify, zip(gt_row, phase))
else:
# TODO: Support datasets without the phasing data
raise NotImplementedError

return generate

def _compose_tag_generator(
self, tag: str, *, subfield=False, sample_loop=False
) -> Callable:
assert tag.startswith("%")
tag = tag[1:]

if tag == "GT":
return self._compose_gt_generator()

def generate(root):
vcz_names = set(name for name, _zarray in root.items())
vcz_name = vcf_name_to_vcz_name(vcz_names, tag)
Expand All @@ -104,8 +150,7 @@ def generate(root):
if tag == "REF":
row = row[0]
if tag == "ALT":
row = [allele for allele in row[1:] if allele]
row = row or "."
row = [allele for allele in row[1:] if allele] or "."
if tag == "FILTER":
assert filter_ids is not None

Expand All @@ -118,49 +163,100 @@ def generate(root):
row = "."
else:
row = f"{row:g}"
if not subscript and (
isinstance(row, np.ndarray) or isinstance(row, list)
if (
not subfield
and not sample_loop
and (isinstance(row, np.ndarray) or isinstance(row, list))
):
row = ",".join(map(str, row))

yield row if not is_missing else "."
if sample_loop:
sample_count = root["sample_id"].shape[0]

if isinstance(row, np.ndarray):
row = row.tolist()
row = [
str(element)
if element != constants.INT_MISSING
else "."
for element in row
if element != constants.INT_FILL
]
yield row
else:
yield itertools.repeat(str(row), sample_count)
else:
yield row if not is_missing else "."

return generate

def _compose_subscript_generator(self, parse_results: pp.ParseResults) -> Callable:
def _compose_subfield_generator(self, parse_results: pp.ParseResults) -> Callable:
assert len(parse_results) == 2

tag, subscript_index = parse_results
tag_generator = self._compose_tag_generator(tag, subscript=True)
tag, subfield_index = parse_results
tag_generator = self._compose_tag_generator(tag, subfield=True)

def generate(root):
for tag in tag_generator(root):
if isinstance(tag, str):
assert tag == "."
yield "."
else:
if subscript_index < len(tag):
yield tag[subscript_index]
if subfield_index < len(tag):
yield tag[subfield_index]
else:
yield "."

return generate

def _compose_sample_loop_generator(
self, parse_results: pp.ParseResults
) -> Callable:
generators = map(
functools.partial(self._compose_element_generator, sample_loop=True),
parse_results,
)

def generate(root):
iterables = (generator(root) for generator in generators)
zipped = zip(*iterables)
zipped_zipped = (zip(*element) for element in zipped)
flattened_zipped_zipped = (
(
subsubelement
for subelement in element # sample-wise
for subsubelement in subelement
)
for element in zipped_zipped # variant-wise
)
yield from map("".join, flattened_zipped_zipped)

return generate

def _compose_element_generator(
self, element: Union[str, pp.ParseResults]
self, element: Union[str, pp.ParseResults], *, sample_loop=False
) -> Callable:
if isinstance(element, pp.ParseResults):
return self._compose_subscript_generator(element)
if element.get_name() == "subfield":
return self._compose_subfield_generator(element)
elif element.get_name() == "sample loop":
return self._compose_sample_loop_generator(element)

assert isinstance(element, str)

if element.startswith("%"):
return self._compose_tag_generator(element)
return self._compose_tag_generator(element, sample_loop=sample_loop)
else:

def generate(root):
nonlocal element
variant_count = root["variant_position"].shape[0]
yield from itertools.repeat(element, variant_count)
if sample_loop:
sample_count = root["sample_id"].shape[0]
for _ in range(variant_count):
yield itertools.repeat(element, sample_count)
else:
yield from itertools.repeat(element, variant_count)

return generate

Expand Down