diff --git a/tests/data/vcf/chr22.vcf.gz b/tests/data/vcf/chr22.vcf.gz new file mode 100644 index 0000000..0bd64cb Binary files /dev/null and b/tests/data/vcf/chr22.vcf.gz differ diff --git a/tests/data/vcf/chr22.vcf.gz.csi b/tests/data/vcf/chr22.vcf.gz.csi new file mode 100644 index 0000000..930ed67 Binary files /dev/null and b/tests/data/vcf/chr22.vcf.gz.csi differ diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py index 830669c..9ad3f8b 100644 --- a/tests/test_bcftools_validation.py +++ b/tests/test_bcftools_validation.py @@ -34,6 +34,7 @@ def run_vcztools(args: str) -> str: ("args", "vcf_file"), [ ("view --no-version", "sample.vcf.gz"), + ("view --no-version", "chr22.vcf.gz"), ("view --no-version -i 'INFO/DP > 10'", "sample.vcf.gz"), ("view --no-version -i 'FMT/DP >= 5 && FMT/GQ > 10'", "sample.vcf.gz"), ("view --no-version -i 'FMT/DP >= 5 & FMT/GQ>10'", "sample.vcf.gz"), diff --git a/tests/utils.py b/tests/utils.py index e48cc9e..f27d8ab 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -140,8 +140,17 @@ def vcz_path_cache(vcf_path): cache_path.mkdir() cached_vcz_path = (cache_path / vcf_path.name).with_suffix(".vcz") if not cached_vcz_path.exists(): - vcf2zarr.convert( - [vcf_path], cached_vcz_path, worker_processes=0, local_alleles=False - ) + if vcf_path.name.startswith("chr22"): + vcf2zarr.convert( + [vcf_path], + cached_vcz_path, + worker_processes=0, + variants_chunk_size=10, + samples_chunk_size=10, + ) + else: + vcf2zarr.convert( + [vcf_path], cached_vcz_path, worker_processes=0, local_alleles=False + ) create_index(cached_vcz_path) return cached_vcz_path diff --git a/vcztools/vcf_writer.py b/vcztools/vcf_writer.py index 7b2fb48..0240939 100644 --- a/vcztools/vcf_writer.py +++ b/vcztools/vcf_writer.py @@ -1,3 +1,4 @@ +import concurrent.futures import functools import io import re @@ -188,19 +189,28 @@ def write_vcf( if variant_regions is None and variant_targets is None: # no regions or targets selected - for v_chunk in range(pos.cdata_shape[0]): - v_mask_chunk = filter_evaluator(v_chunk) if filter_evaluator else None - c_chunk_to_vcf( - root, - v_chunk, - v_mask_chunk, - samples_selection, - contigs, - filters, - output, - drop_genotypes=drop_genotypes, - no_update=no_update, - ) + with concurrent.futures.ThreadPoolExecutor() as executor: + preceding_future = None + for v_chunk in range(pos.cdata_shape[0]): + v_mask_chunk = ( + filter_evaluator(v_chunk) if filter_evaluator else None + ) + future = executor.submit( + c_chunk_to_vcf, + root, + v_chunk, + v_mask_chunk, + samples_selection, + contigs, + filters, + output, + drop_genotypes=drop_genotypes, + no_update=no_update, + preceding_future=preceding_future, + ) + if preceding_future: + concurrent.futures.wait((preceding_future,)) + preceding_future = future else: contigs_u = root["contig_id"][:].astype("U").tolist() regions = parse_regions(variant_regions, contigs_u) @@ -245,26 +255,32 @@ def write_vcf( # Use zarr arrays to get mask chunks aligned with the main data # for convenience. z_variant_mask = zarr.array(variant_mask, chunks=pos.chunks[0]) - - for i, v_chunk in enumerate(chunk_indexes): - v_mask_chunk = z_variant_mask.blocks[i] - - if filter_evaluator and np.any(v_mask_chunk): - v_mask_chunk = np.logical_and( - v_mask_chunk, filter_evaluator(v_chunk) - ) - if np.any(v_mask_chunk): - c_chunk_to_vcf( - root, - v_chunk, - v_mask_chunk, - samples_selection, - contigs, - filters, - output, - drop_genotypes=drop_genotypes, - no_update=no_update, - ) + with concurrent.futures.ThreadPoolExecutor() as executor: + preceding_future = None + for i, v_chunk in enumerate(chunk_indexes): + v_mask_chunk = z_variant_mask.blocks[i] + + if filter_evaluator and np.any(v_mask_chunk): + v_mask_chunk = np.logical_and( + v_mask_chunk, filter_evaluator(v_chunk) + ) + if np.any(v_mask_chunk): + future = executor.submit( + c_chunk_to_vcf, + root, + v_chunk, + v_mask_chunk, + samples_selection, + contigs, + filters, + output, + drop_genotypes=drop_genotypes, + no_update=no_update, + preceding_future=preceding_future, + ) + if preceding_future: + concurrent.futures.wait((preceding_future,)) + preceding_future = future def get_vchunk_array(zarray, v_chunk, mask, samples_selection=None): @@ -291,6 +307,7 @@ def c_chunk_to_vcf( *, drop_genotypes, no_update, + preceding_future: concurrent.futures.Future | None = None, ): chrom = contigs[get_vchunk_array(root.variant_contig, v_chunk, v_mask_chunk)] # TODO check we don't truncate silently by doing this @@ -299,67 +316,71 @@ def c_chunk_to_vcf( ) id = get_vchunk_array(root.variant_id, v_chunk, v_mask_chunk).astype("S") alleles = get_vchunk_array(root.variant_allele, v_chunk, v_mask_chunk) - ref = alleles[:, 0].astype("S") - alt = alleles[:, 1:].astype("S") qual = get_vchunk_array(root.variant_quality, v_chunk, v_mask_chunk) filter_ = get_vchunk_array(root.variant_filter, v_chunk, v_mask_chunk) - - num_variants = len(pos) - if len(id.shape) == 1: - id = id.reshape((num_variants, 1)) - - # TODO gathering fields and doing IO will be done separately later so that - # we avoid retrieving stuff we don't need. format_fields = {} info_fields = {} num_samples = len(samples_selection) if samples_selection is not None else None - for name, array in root.items(): - if ( - name.startswith("call_") - and not name.startswith("call_genotype") - and num_samples != 0 - ): - vcf_name = name[len("call_") :] - format_fields[vcf_name] = get_vchunk_array( - array, v_chunk, v_mask_chunk, samples_selection - ) - if num_samples is None: - num_samples = array.shape[1] - elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES: - vcf_name = name[len("variant_") :] - info_fields[vcf_name] = get_vchunk_array(array, v_chunk, v_mask_chunk) - gt = None gt_phased = None if "call_genotype" in root and not drop_genotypes: - array = root["call_genotype"] - if samples_selection is not None and num_samples != 0: - gt = get_vchunk_array(array, v_chunk, v_mask_chunk, samples_selection) + gt = get_vchunk_array( + root["call_genotype"], v_chunk, v_mask_chunk, samples_selection + ) else: - gt = get_vchunk_array(array, v_chunk, v_mask_chunk) + gt = get_vchunk_array(root["call_genotype"], v_chunk, v_mask_chunk) - if not no_update and samples_selection is not None: - # Recompute INFO/AC and INFO/AN - info_fields |= _compute_info_fields(gt, alt) - if num_samples == 0: - gt = None if ( "call_genotype_phased" in root and not drop_genotypes and (samples_selection is None or num_samples > 0) ): - array = root["call_genotype_phased"] gt_phased = get_vchunk_array( - array, v_chunk, v_mask_chunk, samples_selection + root["call_genotype_phased"], + v_chunk, + v_mask_chunk, + samples_selection, ) else: gt_phased = np.zeros_like(gt, dtype=bool) + for name, zarray in root.items(): + if ( + name.startswith("call_") + and not name.startswith("call_genotype") + and num_samples != 0 + ): + vcf_name = name[len("call_") :] + format_fields[vcf_name] = get_vchunk_array( + zarray, v_chunk, v_mask_chunk, samples_selection + ) + if num_samples is None: + num_samples = zarray.shape[1] + elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES: + vcf_name = name[len("variant_") :] + info_fields[vcf_name] = get_vchunk_array(zarray, v_chunk, v_mask_chunk) + + ref = alleles[:, 0].astype("S") + alt = alleles[:, 1:].astype("S") + + if len(id.shape) == 1: + id = id.reshape((-1, 1)) + if ( + not no_update + and samples_selection is not None + and "call_genotype" in root + and not drop_genotypes + ): + # Recompute INFO/AC and INFO/AN + info_fields |= _compute_info_fields(gt, alt) + if num_samples == 0: + gt = None if gt is not None and num_samples is None: num_samples = gt.shape[1] + num_variants = len(pos) encoder = _vcztools.VcfEncoder( num_variants, num_samples if num_samples is not None else 0, @@ -375,21 +396,25 @@ def c_chunk_to_vcf( # print(encoder.arrays) if gt is not None: encoder.add_gt_field(gt, gt_phased) - for name, array in info_fields.items(): + for name, zarray in info_fields.items(): # print(array.dtype.kind) - if array.dtype.kind in ("O", "U"): - array = array.astype("S") - if len(array.shape) == 1: - array = array.reshape((num_variants, 1)) - encoder.add_info_field(name, array) + if zarray.dtype.kind in ("O", "U"): + zarray = zarray.astype("S") + if len(zarray.shape) == 1: + zarray = zarray.reshape((num_variants, 1)) + encoder.add_info_field(name, zarray) if num_samples != 0: - for name, array in format_fields.items(): - if array.dtype.kind in ("O", "U"): - array = array.astype("S") - if len(array.shape) == 2: - array = array.reshape((num_variants, num_samples, 1)) - encoder.add_format_field(name, array) + for name, zarray in format_fields.items(): + if zarray.dtype.kind in ("O", "U"): + zarray = zarray.astype("S") + if len(zarray.shape) == 2: + zarray = zarray.reshape((num_variants, num_samples, 1)) + encoder.add_format_field(name, zarray) + + if preceding_future: + concurrent.futures.wait((preceding_future,)) + # TODO: (1) make a guess at this based on number of fields and samples, # and (2) log a DEBUG message when we have to double. buflen = 1024