diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py index 3a5cb88..6cfb9ee 100644 --- a/tests/test_bcftools_validation.py +++ b/tests/test_bcftools_validation.py @@ -74,6 +74,8 @@ def test_vcf_output(tmp_path, args, vcf_file): [ ("index -n", "sample.vcf.gz"), ("index --nrecords", "1kg_2020_chrM.vcf.gz"), + ("index -s", "sample.vcf.gz"), + ("index --stats", "1kg_2020_chrM.vcf.gz"), ("query -l", "sample.vcf.gz"), ("query --list-samples", "1kg_2020_chrM.vcf.gz"), ], diff --git a/tests/test_stats.py b/tests/test_stats.py index e4787bb..e705d83 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -1,7 +1,7 @@ import pathlib from io import StringIO -from vcztools.stats import nrecords +from vcztools.stats import nrecords, stats from .utils import vcz_path_cache @@ -13,3 +13,19 @@ def test_nrecords(): output_str = StringIO() nrecords(vcz, output_str) assert output_str.getvalue() == "9\n" + + +def test_stats(): + original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz" + vcz = vcz_path_cache(original) + + output_str = StringIO() + stats(vcz, output_str) + + assert ( + output_str.getvalue() + == """19 . 2 +20 . 6 +X . 1 +""" + ) diff --git a/vcztools/cli.py b/vcztools/cli.py index 873880a..10d08aa 100644 --- a/vcztools/cli.py +++ b/vcztools/cli.py @@ -3,7 +3,8 @@ import click from . import query as query_module -from . import regions, stats, vcf_writer +from . import regions, vcf_writer +from . import stats as stats_module class NaturalOrderGroup(click.Group): @@ -23,9 +24,17 @@ def list_commands(self, ctx): is_flag=True, help="Print the number of records (variants).", ) -def index(path, nrecords): +@click.option( + "-s", + "--stats", + is_flag=True, + help="Print per contig stats.", +) +def index(path, nrecords, stats): if nrecords: - stats.nrecords(path, sys.stdout) + stats_module.nrecords(path, sys.stdout) + elif stats: + stats_module.stats(path, sys.stdout) else: regions.create_index(path) diff --git a/vcztools/stats.py b/vcztools/stats.py index 460802e..25ee6bc 100644 --- a/vcztools/stats.py +++ b/vcztools/stats.py @@ -1,3 +1,4 @@ +import numpy as np import zarr from vcztools.utils import open_file_like @@ -9,3 +10,29 @@ def nrecords(vcz, output): with open_file_like(output) as output: num_variants = root["variant_position"].shape[0] print(num_variants, file=output) + + +def stats(vcz, output): + root = zarr.open(vcz, mode="r") + + with open_file_like(output) as output: + contigs = root["contig_id"][:].astype("U").tolist() + if "contig_length" in root: + contig_lengths = root["contig_length"][:] + else: + contig_lengths = ["."] * len(contigs) + + region_index = root["region_index"][:] + + contig_indexes = region_index[:, 1] + num_records = region_index[:, 5] + + num_records_per_contig = np.bincount( + contig_indexes, weights=num_records + ).astype(np.int64) + + for contig, contig_length, nr in zip( + contigs, contig_lengths, num_records_per_contig + ): + if nr > 0: + print(f"{contig}\t{contig_length}\t{nr}", file=output)