Support excluding samples with a leading "^" in `view -s` / `view -S`.

Notes on this revision of the patch:
  * `-s` combined with `-S` is rejected with ValueError instead of `assert`,
    so the check survives `python -O` and matches the existing ValueError
    raised for a bad output extension.
  * Only the single leading "^" marker is removed (slice, not lstrip), so
    file or sample names that themselves begin with "^" are preserved.
  * The `index` blob ids below are carried over from the original patch;
    the post-image ids of the files changed by this revision will differ.

diff --git a/tests/test_bcftools_validation.py b/tests/test_bcftools_validation.py
index d5dd564..2d63962 100644
--- a/tests/test_bcftools_validation.py
+++ b/tests/test_bcftools_validation.py
@@ -61,7 +61,11 @@ def run_vcztools(args: str) -> str:
         ("view --no-version -s NA00001", "sample.vcf.gz"),
         ("view --no-version -s NA00001,NA00003", "sample.vcf.gz"),
         ("view --no-version -s HG00096", "1kg_2020_chrM.vcf.gz"),
-        ("view --no-version -s '' --force-samples", "sample.vcf.gz")
+        ("view --no-version -s '' --force-samples", "sample.vcf.gz"),
+        ("view --no-version -s ^NA00001", "sample.vcf.gz"),
+        ("view --no-version -s ^NA00003,NA00002", "sample.vcf.gz"),
+        ("view --no-version -s ^NA00003,NA00002,NA00003", "sample.vcf.gz"),
+        ("view --no-version -S ^tests/data/txt/samples.txt", "sample.vcf.gz"),
     ]
 )
 # fmt: on
diff --git a/tests/test_cli.py b/tests/test_cli.py
index d0f1c38..a955e5d 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -7,21 +7,32 @@
 from tests.utils import vcz_path_cache
 
 
-def test_version_header():
+@pytest.fixture()
+def vcz_path():
     vcf_path = pathlib.Path("tests/data/vcf/sample.vcf.gz")
-    vcz_path = vcz_path_cache(vcf_path)
+    return vcz_path_cache(vcf_path)
+
 
+def test_version_header(vcz_path):
     output = run_vcztools(f"view {vcz_path}")
     assert output.find("##vcztools_viewCommand=") >= 0
     assert output.find("Date=") >= 0
 
 
-def test_view_bad_output(tmp_path):
-    vcf_path = pathlib.Path("tests/data/vcf/sample.vcf.gz")
-    vcz_path = vcz_path_cache(vcf_path)
+def test_view_bad_output(tmp_path, vcz_path):
     bad_output = tmp_path / "output.vcf.gz"
 
     with pytest.raises(
         ValueError, match=re.escape("Output file extension must be .vcf, got: .gz")
     ):
         run_vcztools(f"view --no-version {vcz_path} -o {bad_output}")
+
+
+def test_excluding_and_including_samples(vcz_path):
+    samples_file_path = pathlib.Path("tests/data/txt/samples.txt")
+    error_message = re.escape("vcztools does not support combining -s and -S")
+
+    with pytest.raises(ValueError, match=error_message):
+        run_vcztools(f"view {vcz_path} -s NA00001 -S ^{samples_file_path}")
+    with pytest.raises(ValueError, match=error_message):
+        run_vcztools(f"view {vcz_path} -s ^NA00001 -S {samples_file_path}")
diff --git a/tests/test_vcf_writer.py b/tests/test_vcf_writer.py
index 0df9c1c..57f7845 100644
--- a/tests/test_vcf_writer.py
+++ b/tests/test_vcf_writer.py
@@ -145,14 +145,17 @@ def test_write_vcf__regions(tmp_path, regions, targets,
         assert variant.POS == pos
 
 @pytest.mark.parametrize(
-    ("samples", "expected_genotypes"),
+    ("samples", "expected_samples", "expected_genotypes"),
     [
-        ("NA00001", [[0, 0, True]]),
-        ("NA00001,NA00003", [[0, 0, True], [0, 1, False]]),
-        ("NA00003,NA00001", [[0, 1, False], [0, 0, True]]),
+        ("NA00001", ["NA00001"], [[0, 0, True]]),
+        ("NA00001,NA00003", ["NA00001", "NA00003"], [[0, 0, True], [0, 1, False]]),
+        ("NA00003,NA00001", ["NA00003", "NA00001"], [[0, 1, False], [0, 0, True]]),
+        ("^NA00002", ["NA00001", "NA00003"], [[0, 0, True], [0, 1, False]]),
+        ("^NA00003,NA00002", ["NA00001"], [[0, 0, True]]),
+        ("^NA00003,NA00002,NA00003", ["NA00001"], [[0, 0, True]]),
     ]
 )
-def test_write_vcf__samples(tmp_path, samples, expected_genotypes):
+def test_write_vcf__samples(tmp_path, samples, expected_samples, expected_genotypes):
     original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
     vcz = vcz_path_cache(original)
     output = tmp_path.joinpath("output.vcf")
@@ -161,7 +164,7 @@ def test_write_vcf__samples(tmp_path, samples, expected_genotypes):
 
     v = VCF(output)
 
-    assert v.samples == samples.split(",")
+    assert v.samples == expected_samples
 
     variant = next(v)
 
diff --git a/vcztools/cli.py b/vcztools/cli.py
index 049de22..99e5a30 100644
--- a/vcztools/cli.py
+++ b/vcztools/cli.py
@@ -149,8 +149,19 @@ def view(
         raise ValueError(f"Output file extension must be .vcf, got: .{split[-1]}")
 
     if samples_file:
+        if samples:
+            raise ValueError("vcztools does not support combining -s and -S")
+
+        samples = ""
+        exclude_samples_file = samples_file.startswith("^")
+        if exclude_samples_file:
+            # Drop exactly one caret: lstrip("^") would also eat carets that
+            # belong to the file name itself.
+            samples_file = samples_file[1:]
+
         with open(samples_file) as file:
-            samples = samples or ""
+            if exclude_samples_file:
+                samples = "^" + samples
             samples += ",".join(line.strip() for line in file.readlines())
 
     # TODO: use no_update when fixing https://github.com/sgkit-dev/vcztools/issues/75
diff --git a/vcztools/vcf_writer.py b/vcztools/vcf_writer.py
index d383210..ce6d50d 100644
--- a/vcztools/vcf_writer.py
+++ b/vcztools/vcf_writer.py
@@ -153,10 +153,20 @@ def write_vcf(
             samples_selection = None
         else:
             all_samples = root["sample_id"][:]
+            exclude_samples = samples.startswith("^")
+            if exclude_samples:
+                # Drop exactly one caret; lstrip("^") would strip every one.
+                samples = samples[1:]
             sample_ids = np.array(samples.split(","))
             if np.all(sample_ids == np.array("")):
                 sample_ids = np.empty((0,))
+
             samples_selection = search(all_samples, sample_ids)
+            if exclude_samples:
+                samples_selection = np.setdiff1d(
+                    np.arange(all_samples.size), samples_selection
+                )
+                sample_ids = all_samples[samples_selection]
 
         if not no_header and vcf_header is None:
             if "vcf_header" in root.attrs: