@pytest.mark.parametrize("implementation", ["c", "numba"])
def test_write_vcf__targets(shared_datadir, tmp_path, implementation):
    """Writing with ``variant_targets="20"`` emits only contig-20 records."""
    source_vcf = shared_datadir / "vcf" / "sample.vcf.gz"
    icf_dir = tmp_path / "intermediate.icf"
    vcz_dir = tmp_path / "intermediate.vcz"
    result_vcf = tmp_path / "output.vcf"

    # Build the intermediate VCZ store that write_vcf reads from.
    vcf2zarr.convert(
        [source_vcf], vcz_dir, icf_path=icf_dir, worker_processes=0
    )

    write_vcf(
        vcz_dir,
        result_vcf,
        variant_targets="20",
        implementation=implementation,
    )

    reader = VCF(result_vcf)

    # Sample columns must be untouched by variant filtering.
    assert reader.samples == ["NA00001", "NA00002", "NA00003"]

    # Every surviving record is on the requested contig, and there are six.
    chroms = [record.CHROM for record in reader]
    assert all(chrom == "20" for chrom in chroms)
    assert len(chroms) == 6
def parse_region(region: str) -> tuple[str, Optional[int], Optional[int]]:
    """Return the contig, start position and end position from a region string.

    Accepted forms (positions are returned exactly as written; no coordinate
    conversion happens here):

    * ``contig``            -> ``(contig, None, None)``
    * ``contig:pos``        -> ``(contig, pos, pos)``
    * ``contig:start-``     -> ``(contig, start, None)``
    * ``contig:start-end``  -> ``(contig, start, end)``

    A string that does not end in one of the positional suffixes is treated
    as a bare contig name.
    """
    if re.search(r":\d+-\d*$", region):
        # Split on the last colon only, so earlier colons stay in the contig.
        contig, start_end = region.rsplit(":", 1)
        start, end = start_end.split("-")
        return contig, int(start), int(end) if end else None
    if re.search(r":\d+$", region):
        contig, start = region.rsplit(":", 1)
        return contig, int(start), int(start)
    return region, None, None


def parse_targets(targets: str) -> list[tuple[str, Optional[int], Optional[int]]]:
    """Parse a comma-separated list of region strings with ``parse_region``."""
    return [parse_region(region) for region in targets.split(",")]


def regions_to_selection(
    all_contigs: List[str],
    variant_contig: Any,
    variant_position: Any,
    regions: list[tuple[str, Optional[int], Optional[int]]],
):
    """Return the indexes of the variants that overlap any of ``regions``.

    Parameters
    ----------
    all_contigs
        Contig names; a region's contig is translated to its position in
        this list before being compared with ``variant_contig``.
    variant_contig
        Per-variant contig index (into ``all_contigs``).
    variant_position
        Per-variant position in VCF coordinates; must support ``- 1``.
    regions
        ``(contig, start, end)`` tuples as produced by ``parse_region``.
        ``start``/``end`` of ``None`` mean "from the beginning" / "to the end".

    Returns
    -------
    numpy.ndarray
        Row indexes of the matching variants, in whatever order pyranges
        yields them (do not assume sorted). Empty when nothing matches.

    Raises
    ------
    ValueError
        If a region names a contig that is not in ``all_contigs``.
    """
    # subtract 1 from start coordinate to convert intervals
    # from VCF (1-based, fully-closed) to Python (0-based, half-open)
    df = pd.DataFrame(
        {
            "Chromosome": variant_contig,
            "Start": variant_position - 1,
            "End": variant_position,
        }
    )
    # save original index as column so we can retrieve it after finding overlap
    df["index"] = df.index
    variants = pyranges.PyRanges(df)

    chromosomes = []
    starts = []
    ends = []
    for contig, start, end in regions:
        try:
            contig_index = all_contigs.index(contig)
        except ValueError:
            raise ValueError(
                f"Contig {contig!r} not found in the dataset"
            ) from None

        chromosomes.append(contig_index)
        starts.append(0 if start is None else start - 1)
        ends.append(np.iinfo(np.int64).max if end is None else end)

    query = pyranges.PyRanges(chromosomes=chromosomes, starts=starts, ends=ends)

    overlap = variants.overlap(query)
    # An empty PyRanges exposes an empty DataFrame that lacks the "index"
    # column, so the no-match case has to be handled explicitly.
    if len(overlap) == 0:
        return np.empty(0, dtype=np.int64)
    return overlap.df["index"].to_numpy()
def get_block_selection(zarray, key, mask):
    """Return the rows of chunk ``key`` of ``zarray`` selected by ``mask``.

    Parameters
    ----------
    zarray
        Object exposing chunk-wise access through ``.blocks[key]``.
    key
        Chunk index along the first dimension.
    mask
        Boolean array with one entry per row of the chunk; ``True`` keeps
        the row. Must be boolean, not an integer index array.

    Returns
    -------
    The chunk restricted to the rows where ``mask`` is true. When the mask
    keeps every row the chunk is returned as loaded, without a second copy.
    """
    block = zarray.blocks[key]
    # With no filtering active every mask is all-True; boolean indexing would
    # then copy the whole chunk for nothing, so hand the block back unchanged.
    if np.all(mask):
        return block
    return block[mask]
format_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk) if num_samples is None: num_samples = array.shape[1] elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES: vcf_name = name[len("variant_") :] - info_fields[vcf_name] = array.blocks[v_chunk] + info_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk) gt = None gt_phased = None if "call_genotype" in root: array = root["call_genotype"] - gt = array.blocks[v_chunk] + gt = get_block_selection(array, v_chunk, v_mask_chunk) if "call_genotype_phased" in root: array = root["call_genotype_phased"] - gt_phased = array.blocks[v_chunk] + gt_phased = get_block_selection(array, v_chunk, v_mask_chunk) else: gt_phased = np.zeros_like(gt, dtype=bool) @@ -269,16 +295,16 @@ def c_chunk_to_vcf(root, v_chunk, contigs, filters, output): def numba_chunk_to_vcf( - root, v_chunk, header_info_fields, header_format_fields, contigs, filters, output + root, v_chunk, v_mask_chunk, header_info_fields, header_format_fields, contigs, filters, output ): # fixed fields - chrom = root.variant_contig.blocks[v_chunk] - pos = root.variant_position.blocks[v_chunk] - id = root.variant_id.blocks[v_chunk].astype("S") - alleles = root.variant_allele.blocks[v_chunk].astype("S") - qual = root.variant_quality.blocks[v_chunk] - filter_ = root.variant_filter.blocks[v_chunk] + chrom = get_block_selection(root.variant_contig, v_chunk, v_mask_chunk) + pos = get_block_selection(root.variant_position, v_chunk, v_mask_chunk) + id = get_block_selection(root.variant_id, v_chunk, v_mask_chunk).astype("S") + alleles = get_block_selection(root.variant_allele, v_chunk, v_mask_chunk).astype("S") + qual = get_block_selection(root.variant_quality, v_chunk, v_mask_chunk) + filter_ = get_block_selection(root.variant_filter, v_chunk, v_mask_chunk) n_variants = len(pos) @@ -300,7 +326,7 @@ def numba_chunk_to_vcf( # not the other way around. 
This is probably not what we want to # do, but keeping it this way to preserve tests initially. continue - values = arr.blocks[v_chunk] + values = get_block_selection(arr, v_chunk, v_mask_chunk) if arr.dtype == bool: info_mask[k] = create_mask(values) info_bufs.append(np.zeros(0, dtype=np.uint8)) @@ -339,7 +365,7 @@ def numba_chunk_to_vcf( var = "call_genotype" if key == "GT" else f"call_{key}" if var not in root: continue - values = root[var].blocks[v_chunk] + values = get_block_selection(root[var], v_chunk, v_mask_chunk) if key == "GT": n_samples = values.shape[1] format_mask[k] = create_mask(values) @@ -367,7 +393,7 @@ def numba_chunk_to_vcf( format_indexes = np.empty((len(format_values), n_samples + 1), dtype=np.int32) if "call_genotype_phased" in root: - call_genotype_phased = root["call_genotype_phased"].blocks[v_chunk][:] + call_genotype_phased = get_block_selection(root["call_genotype_phased"], v_chunk, v_mask_chunk)[:] else: call_genotype_phased = np.full((n_variants, n_samples), False, dtype=bool)