Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize I/O operations #97

Merged
Merged 14 commits on Nov 19, 2024
Merged
Binary file added tests/data/vcf/chr22.vcf.gz
Binary file not shown.
Binary file added tests/data/vcf/chr22.vcf.gz.csi
Binary file not shown.
1 change: 1 addition & 0 deletions tests/test_bcftools_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run_vcztools(args: str) -> str:
("args", "vcf_file"),
[
("view --no-version", "sample.vcf.gz"),
("view --no-version", "chr22.vcf.gz"),
("view --no-version -i 'INFO/DP > 10'", "sample.vcf.gz"),
("view --no-version -i 'FMT/DP >= 5 && FMT/GQ > 10'", "sample.vcf.gz"),
("view --no-version -i 'FMT/DP >= 5 & FMT/GQ>10'", "sample.vcf.gz"),
Expand Down
15 changes: 12 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,17 @@ def vcz_path_cache(vcf_path):
cache_path.mkdir()
cached_vcz_path = (cache_path / vcf_path.name).with_suffix(".vcz")
if not cached_vcz_path.exists():
vcf2zarr.convert(
[vcf_path], cached_vcz_path, worker_processes=0, local_alleles=False
)
if vcf_path.name.startswith("chr22"):
vcf2zarr.convert(
[vcf_path],
cached_vcz_path,
worker_processes=0,
variants_chunk_size=10,
samples_chunk_size=10,
)
else:
vcf2zarr.convert(
[vcf_path], cached_vcz_path, worker_processes=0, local_alleles=False
)
create_index(cached_vcz_path)
return cached_vcz_path
187 changes: 106 additions & 81 deletions vcztools/vcf_writer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import concurrent.futures
import functools
import io
import re
Expand Down Expand Up @@ -188,19 +189,28 @@ def write_vcf(

if variant_regions is None and variant_targets is None:
# no regions or targets selected
for v_chunk in range(pos.cdata_shape[0]):
v_mask_chunk = filter_evaluator(v_chunk) if filter_evaluator else None
c_chunk_to_vcf(
root,
v_chunk,
v_mask_chunk,
samples_selection,
contigs,
filters,
output,
drop_genotypes=drop_genotypes,
no_update=no_update,
)
with concurrent.futures.ThreadPoolExecutor() as executor:
preceding_future = None
for v_chunk in range(pos.cdata_shape[0]):
v_mask_chunk = (
filter_evaluator(v_chunk) if filter_evaluator else None
)
future = executor.submit(
c_chunk_to_vcf,
root,
v_chunk,
v_mask_chunk,
samples_selection,
contigs,
filters,
output,
drop_genotypes=drop_genotypes,
no_update=no_update,
preceding_future=preceding_future,
)
if preceding_future:
concurrent.futures.wait((preceding_future,))
preceding_future = future
else:
contigs_u = root["contig_id"][:].astype("U").tolist()
regions = parse_regions(variant_regions, contigs_u)
Expand Down Expand Up @@ -245,26 +255,32 @@ def write_vcf(
# Use zarr arrays to get mask chunks aligned with the main data
# for convenience.
z_variant_mask = zarr.array(variant_mask, chunks=pos.chunks[0])

for i, v_chunk in enumerate(chunk_indexes):
v_mask_chunk = z_variant_mask.blocks[i]

if filter_evaluator and np.any(v_mask_chunk):
v_mask_chunk = np.logical_and(
v_mask_chunk, filter_evaluator(v_chunk)
)
if np.any(v_mask_chunk):
c_chunk_to_vcf(
root,
v_chunk,
v_mask_chunk,
samples_selection,
contigs,
filters,
output,
drop_genotypes=drop_genotypes,
no_update=no_update,
)
with concurrent.futures.ThreadPoolExecutor() as executor:
preceding_future = None
for i, v_chunk in enumerate(chunk_indexes):
v_mask_chunk = z_variant_mask.blocks[i]

if filter_evaluator and np.any(v_mask_chunk):
v_mask_chunk = np.logical_and(
v_mask_chunk, filter_evaluator(v_chunk)
)
if np.any(v_mask_chunk):
future = executor.submit(
c_chunk_to_vcf,
root,
v_chunk,
v_mask_chunk,
samples_selection,
contigs,
filters,
output,
drop_genotypes=drop_genotypes,
no_update=no_update,
preceding_future=preceding_future,
)
if preceding_future:
concurrent.futures.wait((preceding_future,))
preceding_future = future


def get_vchunk_array(zarray, v_chunk, mask, samples_selection=None):
Expand All @@ -291,6 +307,7 @@ def c_chunk_to_vcf(
*,
drop_genotypes,
no_update,
preceding_future: concurrent.futures.Future | None = None,
):
chrom = contigs[get_vchunk_array(root.variant_contig, v_chunk, v_mask_chunk)]
# TODO check we don't truncate silently by doing this
Expand All @@ -299,67 +316,71 @@ def c_chunk_to_vcf(
)
id = get_vchunk_array(root.variant_id, v_chunk, v_mask_chunk).astype("S")
alleles = get_vchunk_array(root.variant_allele, v_chunk, v_mask_chunk)
ref = alleles[:, 0].astype("S")
alt = alleles[:, 1:].astype("S")
qual = get_vchunk_array(root.variant_quality, v_chunk, v_mask_chunk)
filter_ = get_vchunk_array(root.variant_filter, v_chunk, v_mask_chunk)

num_variants = len(pos)
if len(id.shape) == 1:
id = id.reshape((num_variants, 1))

# TODO gathering fields and doing IO will be done separately later so that
# we avoid retrieving stuff we don't need.
format_fields = {}
info_fields = {}
num_samples = len(samples_selection) if samples_selection is not None else None
for name, array in root.items():
if (
name.startswith("call_")
and not name.startswith("call_genotype")
and num_samples != 0
):
vcf_name = name[len("call_") :]
format_fields[vcf_name] = get_vchunk_array(
array, v_chunk, v_mask_chunk, samples_selection
)
if num_samples is None:
num_samples = array.shape[1]
elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES:
vcf_name = name[len("variant_") :]
info_fields[vcf_name] = get_vchunk_array(array, v_chunk, v_mask_chunk)

gt = None
gt_phased = None

if "call_genotype" in root and not drop_genotypes:
array = root["call_genotype"]

if samples_selection is not None and num_samples != 0:
gt = get_vchunk_array(array, v_chunk, v_mask_chunk, samples_selection)
gt = get_vchunk_array(
root["call_genotype"], v_chunk, v_mask_chunk, samples_selection
)
else:
gt = get_vchunk_array(array, v_chunk, v_mask_chunk)
gt = get_vchunk_array(root["call_genotype"], v_chunk, v_mask_chunk)

if not no_update and samples_selection is not None:
# Recompute INFO/AC and INFO/AN
info_fields |= _compute_info_fields(gt, alt)
if num_samples == 0:
gt = None
if (
"call_genotype_phased" in root
and not drop_genotypes
and (samples_selection is None or num_samples > 0)
):
array = root["call_genotype_phased"]
gt_phased = get_vchunk_array(
array, v_chunk, v_mask_chunk, samples_selection
root["call_genotype_phased"],
v_chunk,
v_mask_chunk,
samples_selection,
)
else:
gt_phased = np.zeros_like(gt, dtype=bool)

for name, zarray in root.items():
if (
name.startswith("call_")
and not name.startswith("call_genotype")
and num_samples != 0
):
vcf_name = name[len("call_") :]
format_fields[vcf_name] = get_vchunk_array(
zarray, v_chunk, v_mask_chunk, samples_selection
)
if num_samples is None:
num_samples = zarray.shape[1]
elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES:
vcf_name = name[len("variant_") :]
info_fields[vcf_name] = get_vchunk_array(zarray, v_chunk, v_mask_chunk)

ref = alleles[:, 0].astype("S")
alt = alleles[:, 1:].astype("S")

if len(id.shape) == 1:
id = id.reshape((-1, 1))
if (
not no_update
and samples_selection is not None
and "call_genotype" in root
and not drop_genotypes
):
# Recompute INFO/AC and INFO/AN
info_fields |= _compute_info_fields(gt, alt)
if num_samples == 0:
gt = None
if gt is not None and num_samples is None:
num_samples = gt.shape[1]

num_variants = len(pos)
encoder = _vcztools.VcfEncoder(
num_variants,
num_samples if num_samples is not None else 0,
Expand All @@ -375,21 +396,25 @@ def c_chunk_to_vcf(
# print(encoder.arrays)
if gt is not None:
encoder.add_gt_field(gt, gt_phased)
for name, array in info_fields.items():
for name, zarray in info_fields.items():
# print(array.dtype.kind)
if array.dtype.kind in ("O", "U"):
array = array.astype("S")
if len(array.shape) == 1:
array = array.reshape((num_variants, 1))
encoder.add_info_field(name, array)
if zarray.dtype.kind in ("O", "U"):
zarray = zarray.astype("S")
if len(zarray.shape) == 1:
zarray = zarray.reshape((num_variants, 1))
encoder.add_info_field(name, zarray)

if num_samples != 0:
for name, array in format_fields.items():
if array.dtype.kind in ("O", "U"):
array = array.astype("S")
if len(array.shape) == 2:
array = array.reshape((num_variants, num_samples, 1))
encoder.add_format_field(name, array)
for name, zarray in format_fields.items():
if zarray.dtype.kind in ("O", "U"):
zarray = zarray.astype("S")
if len(zarray.shape) == 2:
zarray = zarray.reshape((num_variants, num_samples, 1))
encoder.add_format_field(name, zarray)

if preceding_future:
concurrent.futures.wait((preceding_future,))

# TODO: (1) make a guess at this based on number of fields and samples,
# and (2) log a DEBUG message when we have to double.
buflen = 1024
Expand Down
Loading