Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic support for '--targets/-t' #18

Merged
merged 6 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"numba",
"zarr>=2.17,<3",
"click",
"pyranges",
]
requires-python = ">=3.9"
dynamic = ["version"]
Expand Down
18 changes: 18 additions & 0 deletions tests/test_regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Optional
import pytest
from vcztools.regions import parse_region


@pytest.mark.parametrize(
"targets, expected",
[
("chr1", ("chr1", None, None)),
("chr1:12", ("chr1", 12, 12)),
("chr1:12-", ("chr1", 12, None)),
("chr1:12-103", ("chr1", 12, 103)),
],
)
def test_parse_region(
targets: str, expected: tuple[str, Optional[int], Optional[int]]
):
assert parse_region(targets) == expected
25 changes: 25 additions & 0 deletions tests/test_vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,31 @@ def test_write_vcf(shared_datadir, tmp_path, output_is_path, implementation):
assert_vcfs_close(path, output)


@pytest.mark.parametrize("implementation", ["c", "numba"])
def test_write_vcf__targets(shared_datadir, tmp_path, implementation):
path = shared_datadir / "vcf" / "sample.vcf.gz"
intermediate_icf = tmp_path.joinpath("intermediate.icf")
intermediate_vcz = tmp_path.joinpath("intermediate.vcz")
output = tmp_path.joinpath("output.vcf")

vcf2zarr.convert(
[path], intermediate_vcz, icf_path=intermediate_icf, worker_processes=0
)

write_vcf(intermediate_vcz, output, variant_targets="20", implementation=implementation)

v = VCF(output)

assert v.samples == ["NA00001", "NA00002", "NA00003"]

count = 0
for variant in v:
assert variant.CHROM == "20"
count += 1

assert count == 6


def test_write_vcf__set_header(shared_datadir, tmp_path):
path = shared_datadir / "vcf" / "sample.vcf.gz"
intermediate_icf = tmp_path.joinpath("intermediate.icf")
Expand Down
11 changes: 9 additions & 2 deletions vcztools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@ def list_commands(self, ctx):
@click.command
@click.argument("path", type=click.Path())
@click.option("-c", is_flag=True, default=False, help="Use C implementation")
def view(path, c):
@click.option(
"-t",
"--targets",
type=str,
default=None,
help="Target regions to include.",
)
def view(path, c, targets):
implementation = "c" if c else "numba"
vcf_writer.write_vcf(path, sys.stdout, implementation=implementation)
vcf_writer.write_vcf(path, sys.stdout, variant_targets=targets, implementation=implementation)


@click.group(cls=NaturalOrderGroup, name="vcztools")
Expand Down
60 changes: 60 additions & 0 deletions vcztools/regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import re
from typing import Any, List, Optional

import numpy as np
import pandas as pd
import pyranges


def parse_region(region: str) -> tuple[str, Optional[int], Optional[int]]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this would return a Region object, like the one in bio2zarr, but I'm not sure which way the dependency is between the two projects (if indeed there is one). So I left it as a tuple for now.

"""Return the contig, start position and end position from a region string."""
if re.search(r":\d+-\d*$", region):
contig, start_end = region.rsplit(":", 1)
start, end = start_end.split("-")
return contig, int(start), int(end) if len(end) > 0 else None
elif re.search(r":\d+$", region):
contig, start = region.rsplit(":", 1)
return contig, int(start), int(start)
else:
contig = region
return contig, None, None


def parse_targets(targets: str) -> list[tuple[str, Optional[int], Optional[int]]]:
return [parse_region(region) for region in targets.split(",")]


def regions_to_selection(
all_contigs: List[str],
variant_contig: Any,
variant_position: Any,
regions: list[tuple[str, Optional[int], Optional[int]]],
):
# subtract 1 from start coordinate to convert intervals
# from VCF (1-based, fully-closed) to Python (0-based, half-open)

df = pd.DataFrame({"Chromosome": variant_contig, "Start": variant_position - 1, "End": variant_position})
# save original index as column so we can retrieve it after finding overlap
df["index"] = df.index
variants = pyranges.PyRanges(df)

chromosomes = []
starts = []
ends = []
for contig, start, end in regions:
if start is None:
start = 0
else:
start -= 1

if end is None:
end = np.iinfo(np.int64).max

chromosomes.append(all_contigs.index(contig))
starts.append(start)
ends.append(end)

query = pyranges.PyRanges(chromosomes=chromosomes, starts=starts, ends=ends)

overlap = variants.overlap(query)
return overlap.df["index"].to_numpy()
84 changes: 55 additions & 29 deletions vcztools/vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import MutableMapping, Optional, TextIO, Union

import numpy as np
from vcztools.regions import parse_targets, regions_to_selection
import zarr

from . import _vcztools
Expand Down Expand Up @@ -80,7 +81,7 @@ def dims(arr):


def write_vcf(
vcz, output, *, vcf_header: Optional[str] = None, implementation="numba"
vcz, output, *, vcf_header: Optional[str] = None, variant_targets=None, implementation="numba"
) -> None:
"""Convert a dataset to a VCF file.

Expand Down Expand Up @@ -163,37 +164,62 @@ def write_vcf(
contigs = root["contig_id"][:].astype("S")
filters = root["filter_id"][:].astype("S")

if variant_targets is None:
variant_mask = np.ones(pos.shape[0], dtype=bool)
else:
regions = parse_targets(variant_targets)
variant_selection = regions_to_selection(
root["contig_id"][:].astype("U").tolist(),
root["variant_contig"],
pos[:],
regions,
)
variant_mask = np.zeros(pos.shape[0], dtype=bool)
variant_mask[variant_selection] = 1
# Use zarr arrays to get mask chunks aligned with the main data
# for convenience.
z_variant_mask = zarr.array(variant_mask, chunks=pos.chunks[0])

for v_chunk in range(pos.cdata_shape[0]):
v_mask_chunk = z_variant_mask.blocks[v_chunk]
if implementation == "numba":
numba_chunk_to_vcf(
root,
v_chunk,
v_mask_chunk,
header_info_fields,
header_format_fields,
contigs,
filters,
output,
)
else:
c_chunk_to_vcf(
root,
v_chunk,
contigs,
filters,
output,
)
count = np.sum(v_mask_chunk)
if count > 0:
c_chunk_to_vcf(
root,
v_chunk,
v_mask_chunk,
contigs,
filters,
output,
)


def get_block_selection(zarray, key, mask):
return zarray.blocks[key][mask]


def c_chunk_to_vcf(root, v_chunk, contigs, filters, output):
chrom = contigs[root.variant_contig.blocks[v_chunk]]
def c_chunk_to_vcf(root, v_chunk, v_mask_chunk, contigs, filters, output):
chrom = contigs[get_block_selection(root.variant_contig, v_chunk, v_mask_chunk)]
# TODO check we don't truncate silently by doing this
pos = root.variant_position.blocks[v_chunk].astype(np.int32)
id = root.variant_id.blocks[v_chunk].astype("S")
alleles = root.variant_allele.blocks[v_chunk]
pos = get_block_selection(root.variant_position, v_chunk, v_mask_chunk).astype(np.int32)
id = get_block_selection(root.variant_id, v_chunk, v_mask_chunk).astype("S")
alleles = get_block_selection(root.variant_allele, v_chunk, v_mask_chunk)
ref = alleles[:, 0].astype("S")
alt = alleles[:, 1:].astype("S")
qual = root.variant_quality.blocks[v_chunk]
filter_ = root.variant_filter.blocks[v_chunk]
qual = get_block_selection(root.variant_quality, v_chunk, v_mask_chunk)
filter_ = get_block_selection(root.variant_filter, v_chunk, v_mask_chunk)

num_variants = len(pos)
if len(id.shape) == 1:
Expand All @@ -207,21 +233,21 @@ def c_chunk_to_vcf(root, v_chunk, contigs, filters, output):
for name, array in root.items():
if name.startswith("call_") and not name.startswith("call_genotype"):
vcf_name = name[len("call_") :]
format_fields[vcf_name] = array.blocks[v_chunk]
format_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk)
if num_samples is None:
num_samples = array.shape[1]
elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES:
vcf_name = name[len("variant_") :]
info_fields[vcf_name] = array.blocks[v_chunk]
info_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk)

gt = None
gt_phased = None
if "call_genotype" in root:
array = root["call_genotype"]
gt = array.blocks[v_chunk]
gt = get_block_selection(array, v_chunk, v_mask_chunk)
if "call_genotype_phased" in root:
array = root["call_genotype_phased"]
gt_phased = array.blocks[v_chunk]
gt_phased = get_block_selection(array, v_chunk, v_mask_chunk)
else:
gt_phased = np.zeros_like(gt, dtype=bool)

Expand Down Expand Up @@ -269,16 +295,16 @@ def c_chunk_to_vcf(root, v_chunk, contigs, filters, output):


def numba_chunk_to_vcf(
root, v_chunk, header_info_fields, header_format_fields, contigs, filters, output
root, v_chunk, v_mask_chunk, header_info_fields, header_format_fields, contigs, filters, output
):
# fixed fields

chrom = root.variant_contig.blocks[v_chunk]
pos = root.variant_position.blocks[v_chunk]
id = root.variant_id.blocks[v_chunk].astype("S")
alleles = root.variant_allele.blocks[v_chunk].astype("S")
qual = root.variant_quality.blocks[v_chunk]
filter_ = root.variant_filter.blocks[v_chunk]
chrom = get_block_selection(root.variant_contig, v_chunk, v_mask_chunk)
pos = get_block_selection(root.variant_position, v_chunk, v_mask_chunk)
id = get_block_selection(root.variant_id, v_chunk, v_mask_chunk).astype("S")
alleles = get_block_selection(root.variant_allele, v_chunk, v_mask_chunk).astype("S")
qual = get_block_selection(root.variant_quality, v_chunk, v_mask_chunk)
filter_ = get_block_selection(root.variant_filter, v_chunk, v_mask_chunk)

n_variants = len(pos)

Expand All @@ -300,7 +326,7 @@ def numba_chunk_to_vcf(
# not the other way around. This is probably not what we want to
# do, but keeping it this way to preserve tests initially.
continue
values = arr.blocks[v_chunk]
values = get_block_selection(arr, v_chunk, v_mask_chunk)
if arr.dtype == bool:
info_mask[k] = create_mask(values)
info_bufs.append(np.zeros(0, dtype=np.uint8))
Expand Down Expand Up @@ -339,7 +365,7 @@ def numba_chunk_to_vcf(
var = "call_genotype" if key == "GT" else f"call_{key}"
if var not in root:
continue
values = root[var].blocks[v_chunk]
values = get_block_selection(root[var], v_chunk, v_mask_chunk)
if key == "GT":
n_samples = values.shape[1]
format_mask[k] = create_mask(values)
Expand Down Expand Up @@ -367,7 +393,7 @@ def numba_chunk_to_vcf(
format_indexes = np.empty((len(format_values), n_samples + 1), dtype=np.int32)

if "call_genotype_phased" in root:
call_genotype_phased = root["call_genotype_phased"].blocks[v_chunk][:]
call_genotype_phased = get_block_selection(root["call_genotype_phased"], v_chunk, v_mask_chunk)[:]
else:
call_genotype_phased = np.full((n_variants, n_samples), False, dtype=bool)

Expand Down