Skip to content

Commit

Permalink
Add basic support for '--targets/-t'
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jul 11, 2024
1 parent a0b5342 commit 601f80f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 14 deletions.
12 changes: 12 additions & 0 deletions tests/test_regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Optional
import pytest
from vcztools.regions import parse_targets_string

@pytest.mark.parametrize(
    "targets, expected",
    [
        ("chr1:12-103", ("chr1", 12, 103)),
    ],
)
def test_parse_targets_string(
    targets: str,
    expected: tuple[str, Optional[int], Optional[int]],
):
    """Check that a targets string parses to the expected (contig, start, end)."""
    parsed = parse_targets_string(targets)
    assert parsed == expected
11 changes: 9 additions & 2 deletions vcztools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@ def list_commands(self, ctx):
@click.command
@click.argument("path", type=click.Path())
@click.option("-c", is_flag=True, default=False, help="Use C implementation")
def view(path, c):
@click.option(
"-t",
"--targets",
type=str,
default=None,
help="Target regions to include.",
)
def view(path, c, targets):
implementation = "c" if c else "numba"
vcf_writer.write_vcf(path, sys.stdout, implementation=implementation)
vcf_writer.write_vcf(path, sys.stdout, variant_targets=targets, implementation=implementation)


@click.group(cls=NaturalOrderGroup, name="vcztools")
Expand Down
42 changes: 42 additions & 0 deletions vcztools/regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import re
from typing import Any, List, Optional

import numpy as np

def parse_targets_string(targets: str) -> tuple[str, Optional[int], Optional[int]]:
    """Return the contig, start position and end position from a targets string.

    Two forms are recognised: ``contig:start-end`` and the open-ended
    ``contig:start-``. The contig is everything before the final colon, so
    contig names that themselves contain colons are preserved.

    Parameters
    ----------
    targets
        The targets string to parse, e.g. ``"chr1:12-103"`` or ``"chr1:12-"``.

    Returns
    -------
    tuple
        ``(contig, start, end)`` where ``end`` is ``None`` when the range is
        open-ended.

    Raises
    ------
    NotImplementedError
        If ``targets`` does not end in ``:start-end`` or ``:start-``.
    """
    match = re.search(r":(\d+)-(\d*)$", targets)
    if match is None:
        raise NotImplementedError(f"Unsupported targets string: {targets!r}")

    # The matched suffix contains exactly one colon (its first character), so
    # everything before it is the contig name.
    contig = targets[: match.start()]
    start, end = match.groups()
    # An empty end ("chr1:12-") means "to the end of the contig"; int("")
    # would raise ValueError, so map it to None instead.
    return contig, int(start), int(end) if end else None


def pslice_to_slice(
    all_contigs: List[str],
    variant_contig: Any,
    variant_position: Any,
    contig: str,
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> slice:
    """Convert a contig name and optional position bounds to an index slice.

    The contig is looked up in ``all_contigs`` (``list.index`` raises
    ``ValueError`` if it is absent) and its run of variants is located in
    ``variant_contig`` with ``np.searchsorted``. If ``start`` and/or ``end``
    are given, the run is narrowed by a second search over the positions
    within that run.

    NOTE(review): ``np.searchsorted`` is used with its default left-side
    search, so the result covers positions ``>= start`` and ``< end``. It
    presumably relies on ``variant_contig`` and the per-contig positions being
    sorted ascending -- confirm against callers.
    """
    contig_index = all_contigs.index(contig)
    contig_lo, contig_hi = np.searchsorted(
        variant_contig, [contig_index, contig_index + 1]
    )

    # No position bounds: the whole contig run is the answer.
    if start is None and end is None:
        return slice(contig_lo, contig_hi)

    positions = variant_position[slice(contig_lo, contig_hi)]

    # Offsets found within the contig's positions are relative to the start
    # of the run, so shift them back to absolute variant indexes.
    first = (
        contig_lo
        if start is None
        else contig_lo + np.searchsorted(positions, [start])[0]
    )
    last = (
        contig_hi
        if end is None
        else contig_lo + np.searchsorted(positions, [end])[0]
    )
    return slice(first, last)
42 changes: 30 additions & 12 deletions vcztools/vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import MutableMapping, Optional, TextIO, Union

import numpy as np
from vcztools.regions import parse_targets_string, pslice_to_slice
import zarr

from . import _vcztools
Expand Down Expand Up @@ -80,7 +81,7 @@ def dims(arr):


def write_vcf(
vcz, output, *, vcf_header: Optional[str] = None, implementation="numba"
vcz, output, *, vcf_header: Optional[str] = None, variant_targets=None, implementation="numba"
) -> None:
"""Convert a dataset to a VCF file.
Expand Down Expand Up @@ -163,7 +164,19 @@ def write_vcf(
contigs = root["contig_id"][:].astype("S")
filters = root["filter_id"][:].astype("S")

if variant_targets is None:
variant_mask = np.ones(pos.shape[0], dtype=bool)
else:
contig, start, end = parse_targets_string(variant_targets)
variant_slice = pslice_to_slice(root["contig_id"][:].astype("U").tolist(), root["variant_contig"], pos, contig, start, end)
variant_mask = np.zeros(pos.shape[0], dtype=bool)
variant_mask[variant_slice] = 1
# Use zarr arrays to get mask chunks aligned with the main data
# for convenience.
z_variant_mask = zarr.array(variant_mask, chunks=pos.chunks[0])

for v_chunk in range(pos.cdata_shape[0]):
v_mask_chunk = z_variant_mask.blocks[v_chunk]
if implementation == "numba":
numba_chunk_to_vcf(
root,
Expand All @@ -178,22 +191,27 @@ def write_vcf(
c_chunk_to_vcf(
root,
v_chunk,
v_mask_chunk,
contigs,
filters,
output,
)


def c_chunk_to_vcf(root, v_chunk, contigs, filters, output):
chrom = contigs[root.variant_contig.blocks[v_chunk]]
def get_block_selection(zarray, key, mask):
    """Return block ``key`` of ``zarray`` indexed by ``mask``.

    The block is fetched through ``zarray.blocks[key]`` and the result is then
    indexed with ``mask``.
    """
    block = zarray.blocks[key]
    return block[mask]


def c_chunk_to_vcf(root, v_chunk, v_mask_chunk, contigs, filters, output):
chrom = contigs[get_block_selection(root.variant_contig, v_chunk, v_mask_chunk)]
# TODO check we don't truncate silently by doing this
pos = root.variant_position.blocks[v_chunk].astype(np.int32)
id = root.variant_id.blocks[v_chunk].astype("S")
alleles = root.variant_allele.blocks[v_chunk]
pos = get_block_selection(root.variant_position, v_chunk, v_mask_chunk).astype(np.int32)
id = get_block_selection(root.variant_id, v_chunk, v_mask_chunk).astype("S")
alleles = get_block_selection(root.variant_allele, v_chunk, v_mask_chunk)
ref = alleles[:, 0].astype("S")
alt = alleles[:, 1:].astype("S")
qual = root.variant_quality.blocks[v_chunk]
filter_ = root.variant_filter.blocks[v_chunk]
qual = get_block_selection(root.variant_quality, v_chunk, v_mask_chunk)
filter_ = get_block_selection(root.variant_filter, v_chunk, v_mask_chunk)

num_variants = len(pos)
if len(id.shape) == 1:
Expand All @@ -207,21 +225,21 @@ def c_chunk_to_vcf(root, v_chunk, contigs, filters, output):
for name, array in root.items():
if name.startswith("call_") and not name.startswith("call_genotype"):
vcf_name = name[len("call_") :]
format_fields[vcf_name] = array.blocks[v_chunk]
format_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk)
if num_samples is None:
num_samples = array.shape[1]
elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES:
vcf_name = name[len("variant_") :]
info_fields[vcf_name] = array.blocks[v_chunk]
info_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk)

gt = None
gt_phased = None
if "call_genotype" in root:
array = root["call_genotype"]
gt = array.blocks[v_chunk]
gt = get_block_selection(array, v_chunk, v_mask_chunk)
if "call_genotype_phased" in root:
array = root["call_genotype_phased"]
gt_phased = array.blocks[v_chunk]
gt_phased = get_block_selection(array, v_chunk, v_mask_chunk)
else:
gt_phased = np.zeros_like(gt, dtype=bool)

Expand Down

0 comments on commit 601f80f

Please sign in to comment.