From 589194e5491a743a2b0de143859f37f5061405a2 Mon Sep 17 00:00:00 2001
From: Jakob Nybo Nissen
Date: Tue, 3 Dec 2024 12:46:06 +0100
Subject: [PATCH] Make taxonomy more consistently used (#374)

More consistently differentiate between unrefined and refined taxonomy.

* Taxometer requires an unrefined taxonomy, and errors on a refined one.
* Recluster with DBScan can take either, but warns if passed a refined one
  without `--no_predictor`. If refinement is needed and the requisite
  composition and abundance are not passed, it errors.
* TaxVamb can take either, but warns like DBScan. It does no additional
  check for composition and abundance, since these are always required
  for TaxVamb.

Also minor fixes in logging.

Also some changes to the DBScan algorithm. The old algorithm has not yet
been validated, so this is not a high-impact change.
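For reference, the two taxonomy formats are distinguished purely by the
header line of the tab-separated file (TAXONOMY_HEADER and
PREDICTED_TAXONOMY_HEADER in vamb/taxonomy.py). A minimal sketch, with
made-up contig names and lineages, columns separated by tabs:

    # Unrefined: two columns
    contigs	predictions
    contig_1	Bacteria;Bacillota;Bacilli

    # Refined (Taxometer output): adds one semicolon-separated score per rank
    contigs	predictions	scores
    contig_1	Bacteria;Bacillota;Bacilli	0.97;0.92;0.71

get_taxonomy in vamb/__main__.py reads only this header line to decide
whether to return an UnrefinedTaxonomy or a RefinedTaxonomy.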
---
 .github/workflows/cli_vamb.yml |   5 +-
 vamb/__main__.py               | 207 ++++++++++++++++++++++-----------
 vamb/reclustering.py           |  82 ++++++-------
 vamb/taxonomy.py               |  61 ++++++++--
 4 files changed, 224 insertions(+), 131 deletions(-)

diff --git a/.github/workflows/cli_vamb.yml b/.github/workflows/cli_vamb.yml
index 9018429d..80e07cde 100644
--- a/.github/workflows/cli_vamb.yml
+++ b/.github/workflows/cli_vamb.yml
@@ -28,7 +28,7 @@ jobs:
         cache-dependency-path: '**/pyproject.toml'
     - name: Download fixtures
       run: |
-        wget https://www.dropbox.com/scl/fi/xzc0tro7oe6tfm3igygpj/ci_data.zip\?rlkey\=xuv6b5eoynfryp4fba1kfp5jm\&st\=rjb1xccw\&dl\=0 -O ci_data.zip
+        wget https://www.dropbox.com/scl/fi/10tdf0w0kf70pf46hy8ks/ci_data.zip\?rlkey\=smlcinkesuwiw557zulgbb59l\&st\=hhokiqma\&dl\=0 -O ci_data.zip
         unzip -o ci_data.zip
     - name: Install dependencies
       run: |
@@ -56,9 +56,6 @@ jobs:
         vamb taxometer --outdir outdir_taxometer --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -pe 10 -pt 10
         ls -la outdir_taxometer
         cat outdir_taxometer/log.txt
-        vamb taxometer --outdir outdir_taxometer_pred --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy outdir_taxometer/results_taxometer.tsv -pe 10 -pt 10
-        ls -la outdir_taxometer_pred
-        cat outdir_taxometer/log.txt
     - name: Run k-means reclustering
       run: |
         vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --markers markers_mock.npz --algorithm kmeans --minfasta 200000
diff --git a/vamb/__main__.py b/vamb/__main__.py
index 12c9a3e7..2d2f9d34 100755
--- a/vamb/__main__.py
+++ b/vamb/__main__.py
@@ -373,30 +373,61 @@ def __init__(
         self.refcheck = refcheck
 
 
-class TaxonomyPath:
-    @classmethod
-    def from_args(cls, args: argparse.Namespace):
-        if args.taxonomy is None:
-            raise argparse.ArgumentTypeError(
-                "Cannot load taxonomy without specifying --taxonomy"
-            )
-        return cls(typeasserted(args.taxonomy, Path))
+class TaxonomyBase:
+    __slots__ = ["path"]
 
     def __init__(self, path: Path):
-        self.path = check_existing_file(path)
+        self.path = path
 
-    def get_tax_path(self) -> Path:
-        return self.path
+
+class RefinedTaxonomy(TaxonomyBase):
+    pass
+
+
+class UnrefinedTaxonomy(TaxonomyBase):
+    pass
+
+
+def get_taxonomy(args: argparse.Namespace) -> Union[RefinedTaxonomy, UnrefinedTaxonomy]:
+    path = args.taxonomy
+    if path is None:
+        raise ValueError(
+            "Cannot load taxonomy without specifying --taxonomy"
+        )
+    with open(check_existing_file(path)) as file:
+        try:
+            header = next(file).rstrip("\r\n")
+        except StopIteration:
+            header = None
+
+    if header is None:
+        raise ValueError(f'Empty taxonomy file at "{path}"')
+    elif header == vamb.taxonomy.TAXONOMY_HEADER:
+        return UnrefinedTaxonomy(path)
+    elif header == vamb.taxonomy.PREDICTED_TAXONOMY_HEADER:
+        return RefinedTaxonomy(path)
+    else:
+        raise ValueError(
+            f'ERROR: When reading taxonomy file at "{path}", '
+            f"the first line was not either {repr(vamb.taxonomy.TAXONOMY_HEADER)} "
+            f"or {repr(vamb.taxonomy.PREDICTED_TAXONOMY_HEADER)}"
+        )
 
 
 class TaxometerOptions:
     @classmethod
     def from_args(cls, args: argparse.Namespace):
+        tax = get_taxonomy(args)
+        if isinstance(tax, RefinedTaxonomy):
+            raise ValueError(
+                f'Attempted to run Taxometer to refine taxonomy at "{args.taxonomy}", '
+                "but this file appears to already be an output of Taxometer"
+            )
         return cls(
             GeneralOptions.from_args(args),
             CompositionOptions.from_args(args),
             AbundanceOptions.from_args(args),
-            TaxonomyPath.from_args(args),
+            tax,
            BasicTrainingOptions.from_args_taxometer(args),
             typeasserted(args.pred_softmax_threshold, float),
             args.ploss,
@@ -407,7 +438,7 @@ def __init__(
         self,
         general: GeneralOptions,
         composition: CompositionOptions,
         abundance: AbundanceOptions,
-        taxonomy: TaxonomyPath,
+        taxonomy: UnrefinedTaxonomy,
         basic: BasicTrainingOptions,
         softmax_threshold: float,
         ploss: Union[
@@ -432,38 +463,6 @@ def __init__(
         self.basic = basic
 
 
-class RefinableTaxonomyOptions:
-    __slots__ = ["path_or_tax_options"]
-
-    @classmethod
-    def from_args(cls, args: argparse.Namespace):
-        predict = not typeasserted(args.no_predictor, bool)
-
-        # TaxometerOptions have more options, but only the composition and the abundance
-        # can be omitted, so we only check those here.
-        if predict:
-            if not CompositionOptions.are_args_present(
-                args
-            ) or not AbundanceOptions.are_args_present(args):
-                raise ValueError(
-                    "If `--no_predictor` is not passed, Taxometer is run to refine taxonomy, "
-                    "and this requires composition input and abundance input to be passed in"
-                )
-            return cls(TaxometerOptions.from_args(args))
-        else:
-            return cls(TaxonomyPath.from_args(args))
-
-    def __init__(self, path_or_tax_options: Union[TaxonomyPath, TaxometerOptions]):
-        self.path_or_tax_options = path_or_tax_options
-
-    def get_tax_path(self) -> Path:
-        p = self.path_or_tax_options
-        if isinstance(p, TaxonomyPath):
-            return p.get_tax_path()
-        else:
-            return p.taxonomy.get_tax_path()
-
-
 class MarkerPath:
     def __init__(self, path: Path):
         self.path = check_existing_file(path)
@@ -504,8 +503,50 @@ def __init__(self, clusters: Path):
 
 
 class DBScanOptions:
-    def __init__(self, taxonomy_options: RefinableTaxonomyOptions, n_processes: int):
-        self.taxonomy_options = taxonomy_options
+    @classmethod
+    def from_args(cls, args: argparse.Namespace, n_threads: int):
+        tax = get_taxonomy(args)
+        predict = not typeasserted(args.no_predictor, bool)
+        if predict:
+            if isinstance(tax, RefinedTaxonomy):
+                logger.warning(
+                    "Flag --no_predictor not set, but the taxonomy passed in "
+                    "on --taxonomy is already refined. Skipping refinement."
+                )
+                return cls(tax, n_threads)
+            else:
+                if not (
+                    CompositionOptions.are_args_present(args)
+                    and AbundanceOptions.are_args_present(args)
+                ):
+                    raise ValueError(
+                        "Flag --no_predictor is not set, but abundance "
+                        "or composition has not been passed in, so there is no information "
+                        "to refine taxonomy with"
+                    )
+                tax_options = TaxometerOptions(
+                    GeneralOptions.from_args(args),
+                    CompositionOptions.from_args(args),
+                    AbundanceOptions.from_args(args),
+                    tax,
+                    BasicTrainingOptions.from_args_taxometer(args),
+                    typeasserted(args.pred_softmax_threshold, float),
+                    args.ploss,
+                )
+                return cls(tax_options, n_threads)
+        else:
+            return cls(tax, n_threads)
+
+    def __init__(
+        self,
+        ops: Union[
+            UnrefinedTaxonomy,
+            RefinedTaxonomy,
+            TaxometerOptions,
+        ],
+        n_processes: int,
+    ):
+        self.taxonomy = ops
         self.n_processes = n_processes
@@ -714,15 +755,35 @@ def from_args(cls, args: argparse.Namespace):
         common = BinnerCommonOptions.from_args(args)
         basic = BasicTrainingOptions.from_args_vae(args)
         vae = VAEOptions.from_args(basic, args)
-        taxonomy = RefinableTaxonomyOptions.from_args(args)
-        return cls(common, vae, taxonomy)
+        taxonomy = get_taxonomy(args)
+        predict = not typeasserted(args.no_predictor, bool)
+        if predict:
+            if isinstance(taxonomy, RefinedTaxonomy):
+                logger.warning(
+                    "Flag --no_predictor not set, but the taxonomy passed in "
+                    "on --taxonomy is already refined. Skipping refinement."
+                )
+                tax = taxonomy
+            else:
+                tax = TaxometerOptions(
+                    common.general,
+                    common.comp,
+                    common.abundance,
+                    taxonomy,
+                    basic,
+                    typeasserted(args.pred_softmax_threshold, float),
+                    args.ploss,
+                )
+        else:
+            tax = taxonomy
+        return cls(common, vae, tax)
 
     # The VAEVAE models share the same settings as the VAE model, so we just use VAEOptions
     def __init__(
         self,
         common: BinnerCommonOptions,
         vae: VAEOptions,
-        taxonomy: RefinableTaxonomyOptions,
+        taxonomy: Union[RefinedTaxonomy, TaxometerOptions, UnrefinedTaxonomy],
     ):
         self.common = common
         self.vae = vae
@@ -768,15 +829,15 @@ def from_args(cls, args: argparse.Namespace):
             )
             algorithm = KmeansOptions(clusters)
         elif args.algorithm == "dbscan":
-            tax = RefinableTaxonomyOptions.from_args(args)
-            algorithm = DBScanOptions(tax, general.n_threads)
+            algorithm = DBScanOptions.from_args(args, general.n_threads)
         else:
             assert False  # no more algorithms
 
+        # Avoid loading composition again if already loaded in DBScanOptions
         if isinstance(algorithm, DBScanOptions) and isinstance(
-            algorithm.taxonomy_options.path_or_tax_options, TaxometerOptions
+            algorithm.taxonomy, TaxometerOptions
         ):
-            comp = algorithm.taxonomy_options.path_or_tax_options.composition
+            comp = algorithm.taxonomy.composition
         else:
             comp = CompositionOptions.from_args(args)
@@ -1456,24 +1517,26 @@ def run_vaevae(opt: BinTaxVambOptions):
         composition.metadata.lengths,
         composition.metadata.identifiers,
     )
-    if isinstance(opt.taxonomy.path_or_tax_options, TaxometerOptions):
-        logger.info("Predicting missing values from taxonomy")
+    if isinstance(opt.taxonomy, TaxometerOptions):
         predicted_contig_taxonomies = predict_taxonomy(
             comp_metadata=composition.metadata,
             abundance=abundance,
             tnfs=tnfs,
             lengths=lengths,
             out_dir=opt.common.general.out_dir,
-            taxonomy_options=opt.taxonomy.path_or_tax_options,
+            taxonomy_options=opt.taxonomy,
             cuda=opt.common.general.cuda,
         )
         contig_taxonomies = predicted_contig_taxonomies.to_taxonomy()
+    elif isinstance(opt.taxonomy, RefinedTaxonomy):
+        logger.info("Loading already-refined taxonomy from file")
+        contig_taxonomies = vamb.taxonomy.Taxonomy.from_refined_file(
+            opt.taxonomy.path, composition.metadata, False
+        )
     else:
-        logger.info("Not predicting the taxonomy")
+        logger.info("Loading unrefined taxonomy from file")
         contig_taxonomies = vamb.taxonomy.Taxonomy.from_file(
-            opt.taxonomy.path_or_tax_options.path,
-            composition.metadata,
-            False,
+            opt.taxonomy.path, composition.metadata, False
         )
 
     nodes, ind_nodes, table_parent = vamb.taxvamb_encode.make_graph(
@@ -1569,9 +1632,6 @@ def run_vaevae(opt: BinTaxVambOptions):
     )
 
 
-# TODO: The whole data flow around predict_taxonomy needs to change.
-# Ideally, we should have a "get taxonomy" function that loads, possibly refines,
-# and then returns the taxonomy object.
 def run_reclustering(opt: ReclusteringOptions):
     composition = calc_tnf(
         opt.composition,
@@ -1587,8 +1647,8 @@ def run_reclustering(opt: ReclusteringOptions):
 
     if isinstance(alg, DBScanOptions):
         # If we should refine or not.
-        if isinstance(alg.taxonomy_options.path_or_tax_options, TaxometerOptions):
-            taxopt = alg.taxonomy_options.path_or_tax_options
+        if isinstance(alg.taxonomy, TaxometerOptions):
+            taxopt = alg.taxonomy
             abundance = calc_abundance(
                 taxopt.abundance,
                 taxopt.general.out_dir,
@@ -1607,11 +1667,16 @@ def run_reclustering(opt: ReclusteringOptions):
             )
             taxonomy = predicted_tax.to_taxonomy()
         else:
-            tax_path = alg.taxonomy_options.path_or_tax_options.path
-            logger.info(f'Loading taxonomy from file "{tax_path}"')
-            taxonomy = vamb.taxonomy.Taxonomy.from_file(
-                tax_path, composition.metadata, True
-            )
+            logger.info(f'Loading taxonomy from file "{alg.taxonomy.path}"')
+            if isinstance(alg.taxonomy, UnrefinedTaxonomy):
+                taxonomy = vamb.taxonomy.Taxonomy.from_file(
+                    alg.taxonomy.path, composition.metadata, True
+                )
+            else:
+                taxonomy = vamb.taxonomy.Taxonomy.from_refined_file(
+                    alg.taxonomy.path, composition.metadata, True
+                )
+
         instantiated_alg = vamb.reclustering.DBScanAlgorithm(
            composition.metadata, taxonomy, opt.general.n_threads
        )
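In CLI terms, the new dispatch means the following (hypothetical
invocations, assembled from the mock file names in the CI workflow above):

    # Now errors: Taxometer refuses to refine an already-refined taxonomy.
    # This is also why the second taxometer step was dropped from the CI workflow.
    vamb taxometer --outdir outdir_taxometer_pred --fasta catalogue_mock.fna.gz \
        --abundance abundance_mock.npz --taxonomy outdir_taxometer/results_taxometer.tsv

    # Warns and skips refinement: refined taxonomy without --no_predictor.
    vamb recluster --outdir outdir_dbscan --fasta catalogue_mock.fna.gz \
        --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz \
        --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --markers markers_mock.npz \
        --algorithm dbscan --taxonomy outdir_taxometer/results_taxometer.tsv --minfasta 200000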
diff --git a/vamb/reclustering.py b/vamb/reclustering.py
index 0c4cfcc8..130ed29e 100644
--- a/vamb/reclustering.py
+++ b/vamb/reclustering.py
@@ -236,12 +236,6 @@ def get_completeness_contamination(counts: np.ndarray) -> tuple[float, float]:
     return (completeness, contamination)
 
 
-# An arbitrary score of a bin, where higher numbers is better.
-# completeness - 5 * contamination is used by the CheckM group as a heuristic.
-def score_from_comp_cont(comp_cont: tuple[float, float]) -> float:
-    return comp_cont[0] - 5 * comp_cont[1]
-
-
 def recluster_dbscan(
     taxonomy: Taxonomy,
     latent: np.ndarray,
@@ -251,15 +245,31 @@ def recluster_dbscan(
 ) -> list[set[ContigId]]:
     # Since DBScan is computationally expensive, and scales poorly with the number
     # of contigs, we use taxonomy to only cluster within each genus
-    result: list[set[ContigId]] = []
-    for indices in group_indices_by_genus(taxonomy):
-        genus_latent = latent[indices]
-        genus_clusters = dbscan_genus(
-            genus_latent, indices, contiglengths[indices], markers, num_processes
-        )
-        result.extend(genus_clusters)
+    n_worse_in_row = 0
+    genera_indices = group_indices_by_genus(taxonomy)
+    best_score = 0
+    best_bins: list[set[ContigId]] = []
+    for eps in EPS_VALUES:
+        bins: list[set[ContigId]] = []
+        for indices in genera_indices:
+            genus_clusters = dbscan_genus(
+                latent[indices], indices, contiglengths[indices], num_processes, eps
+            )
+            bins.extend(genus_clusters)
 
-    return result
+        score = count_good_genomes(bins, markers)
+        if best_score == 0 or score > best_score:
+            best_bins = bins
+            best_score = score
+
+        if score >= best_score:
+            n_worse_in_row = 0
+        else:
+            n_worse_in_row += 1
+            if n_worse_in_row > 2:
+                break
+
+    return best_bins
 
 
 # DBScan within the subset of contigs that are annotated with a single genus
@@ -267,8 +277,8 @@ def dbscan_genus(
     latent_of_genus: np.ndarray,
     original_indices: np.ndarray,
     contiglengths_of_genus: np.ndarray,
-    markers: Markers,
     num_processes: int,
+    eps: float,
 ) -> list[set[ContigId]]:
     assert len(latent_of_genus) == len(original_indices) == len(contiglengths_of_genus)
     # Precompute distance matrix. This is O(N^2), but DBScan is even worse,
@@ -278,41 +288,21 @@ def dbscan_genus(
     distance_matrix = pairwise_distances(
         latent_of_genus, latent_of_genus, metric="cosine"
     )
-    best_bins: tuple[int, dict[int, set[ContigId]]] = (-1, dict())
     # The DBScan approach works by blindly clustering with different eps values
     # (a critical parameter for DBscan), and then using SCGs to select the best
     # subset of clusters.
     # It's ugly and wasteful, but it does work.
-    n_worse_in_row = 0
-    for eps in EPS_VALUES:
-        print(f"Running eps: {eps}")
-        dbscan = DBSCAN(
-            eps=eps,
-            min_samples=5,
-            n_jobs=num_processes,
-            metric="precomputed",
-        )
-        dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
-        these_bins: dict[int, set[ContigId]] = defaultdict(set)
-        for original_index, bin_index in zip(original_indices, dbscan.labels_):
-            these_bins[bin_index].add(ContigId(original_index))
-        score = count_good_genomes(these_bins.values(), markers)
-        print(f"Score: {score}, best: {best_bins[0]}, worse: {n_worse_in_row}")
-        if score > best_bins[0]:
-            best_bins = (score, these_bins)
-            n_worse_in_row = 0
-        elif score < best_bins[0]:
-            # This is an elif statement and not an if statement because we don't
-            # want e.g. a series of [1,1,1,1,1] genomes to be considered a performance
-            # degradation leading to early exit, when in fact, it probably means
-            # that eps is too low and should be increased.
-            n_worse_in_row += 1
-            # If we see 3 worse clusterings in a row, we exit early.
-            if n_worse_in_row > 2:
-                break
-
-    return list(best_bins[1].values())
+    dbscan = DBSCAN(
+        eps=eps,
+        min_samples=5,
+        n_jobs=num_processes,
+        metric="precomputed",
+    )
+    dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
+    bins: dict[int, set[ContigId]] = defaultdict(set)
+    for original_index, bin_index in zip(original_indices, dbscan.labels_):
+        bins[bin_index].add(ContigId(original_index))
+    return list(bins.values())
 
 
 def count_good_genomes(binning: Iterable[Iterable[ContigId]], markers: Markers) -> int:
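To spell out the restructured search above: eps is now swept in the outer
loop, so a single eps is chosen globally (scored by pooling the bins of
every genus), where the old code selected a best eps per genus. A
self-contained sketch of just the selection loop, with clustering and
scoring stubbed out as function arguments (these names are illustrative,
not part of the diff):

    from typing import Callable

    def pick_best_eps_bins(
        eps_values: list[float],
        cluster_at: Callable[[float], list[set[int]]],
        score: Callable[[list[set[int]]], int],
    ) -> list[set[int]]:
        best_score, best_bins = 0, []
        n_worse_in_row = 0
        for eps in eps_values:
            bins = cluster_at(eps)  # DBScan over all genera at this eps
            s = score(bins)         # number of good genomes by SCG counts
            if best_score == 0 or s > best_score:
                best_bins, best_score = bins, s
            if s >= best_score:
                # Tying the best is not counted as degradation
                n_worse_in_row = 0
            else:
                n_worse_in_row += 1
                if n_worse_in_row > 2:
                    break  # three consecutive worse clusterings: stop early
        return best_bins

For example, per-eps scores of [5, 7, 7, 6, 6, 6] stop the sweep after the
third 6 and return the bins that scored 7.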
diff --git a/vamb/taxonomy.py b/vamb/taxonomy.py
index 7e88919d..7a9e1d63 100644
--- a/vamb/taxonomy.py
+++ b/vamb/taxonomy.py
@@ -3,6 +3,9 @@
 from vamb.parsecontigs import CompositionMetaData
 import numpy as np
 
+TAXONOMY_HEADER = "contigs\tpredictions"
+PREDICTED_TAXONOMY_HEADER = "contigs\tpredictions\tscores"
+
 
 class ContigTaxonomy:
     """
@@ -56,6 +59,14 @@ def from_file(
         observed = cls.parse_tax_file(tax_file, is_canonical)
         return cls.from_observed(observed, metadata, is_canonical)
 
+    @classmethod
+    def from_refined_file(
+        cls, tax_file: Path, metadata: CompositionMetaData, is_canonical: bool
+    ):
+        observed = PredictedTaxonomy.parse_tax_file(tax_file, is_canonical)
+        observed = [(name, tax.contig_taxonomy) for (name, tax) in observed]
+        return cls.from_observed(observed, metadata, is_canonical)
+
     @classmethod
     def from_observed(
         cls,
@@ -100,13 +111,12 @@ def parse_tax_file(
     ) -> list[tuple[str, ContigTaxonomy]]:
         with open(path) as file:
             result: list[tuple[str, ContigTaxonomy]] = []
-            lines = filter(None, map(str.rstrip, file))
-            header = next(lines, None)
-            if header is None or not header.startswith("contigs\tpredictions"):
+            header = next(file, None)
+            if header is None or not header.startswith(TAXONOMY_HEADER):
                 raise ValueError(
-                    'In taxonomy file, expected header to begin with "contigs\\tpredictions"'
+                    f"In taxonomy file, expected header to begin with {repr(TAXONOMY_HEADER)}"
                 )
-            for line in lines:
+            for line in file:
                 (contigname, taxonomy, *_) = line.split("\t")
                 result.append(
                     (
@@ -124,7 +134,10 @@ class PredictedContigTaxonomy:
     def __init__(self, tax: ContigTaxonomy, probs: np.ndarray):
         if len(probs) != len(tax.ranks):
             raise ValueError("The length of probs must equal that of ranks")
-        self.tax = tax
+        # Due to floating point errors, the probabilities may be slightly outside of
+        # the interval [0, 1], so we clip them rather than validate them.
+        np.clip(probs, a_min=0.0, a_max=1.0, out=probs)
+        self.contig_taxonomy = tax
         self.probs = probs
@@ -147,28 +160,56 @@ def __init__(
         self.is_canonical = is_canonical
 
     def to_taxonomy(self) -> Taxonomy:
-        lst: list[Optional[ContigTaxonomy]] = [p.tax for p in self.contig_taxonomies]
+        lst: list[Optional[ContigTaxonomy]] = [
+            p.contig_taxonomy for p in self.contig_taxonomies
+        ]
         return Taxonomy(lst, self.refhash, self.is_canonical)
 
     @property
     def nseqs(self) -> int:
         return len(self.contig_taxonomies)
 
+    @staticmethod
+    def parse_tax_file(
+        path: Path, force_canonical: bool
+    ) -> list[tuple[str, PredictedContigTaxonomy]]:
+        with open(path) as file:
+            result: list[tuple[str, PredictedContigTaxonomy]] = []
+            lines = filter(None, map(str.rstrip, file))
+            header = next(lines, None)
+            if header is None or not header.startswith(PREDICTED_TAXONOMY_HEADER):
+                raise ValueError(
+                    f"In predicted taxonomy file, expected header to begin with {repr(PREDICTED_TAXONOMY_HEADER)}"
+                )
+            for line in lines:
+                (contigname, taxonomy, scores, *_) = line.split("\t")
+                contig_taxonomy = ContigTaxonomy.from_semicolon_sep(
+                    taxonomy, force_canonical
+                )
+                probs = np.array([float(i) for i in scores.split(";")], dtype=float)
+                result.append(
+                    (
+                        contigname,
+                        PredictedContigTaxonomy(contig_taxonomy, probs),
+                    )
+                )
+
+        return result
+
     def write_as_tsv(self, file: IO[str], comp_metadata: CompositionMetaData):
         if self.refhash != comp_metadata.refhash:
             raise ValueError(
                 "Refhash of comp_metadata and predicted taxonomy must match"
             )
         assert self.nseqs == comp_metadata.nseqs
-        print("contigs\tpredictions\tlengths\tscores", file=file)
+        print(PREDICTED_TAXONOMY_HEADER, file=file)
         for i in range(self.nseqs):
             tax = self.contig_taxonomies[i]
-            ranks_str = ";".join(tax.tax.ranks)
+            ranks_str = ";".join(tax.contig_taxonomy.ranks)
             probs_str = ";".join([str(round(i, 5)) for i in tax.probs])
             print(
                 comp_metadata.identifiers[i],
                 ranks_str,
-                comp_metadata.lengths[i],
                 probs_str,
                 file=file,
                 sep="\t",
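As a usage sketch of the two parsing entry points (file paths reuse the
mock names from the CI workflow; the CompositionMetaData argument is
assumed to come from an earlier composition parse):

    from pathlib import Path
    from vamb.parsecontigs import CompositionMetaData
    import vamb

    def load_taxonomies(comp_metadata: CompositionMetaData):
        # Unrefined file: header "contigs\tpredictions"
        unrefined = vamb.taxonomy.Taxonomy.from_file(
            Path("taxonomy_mock.tsv"), comp_metadata, False
        )
        # Refined file: header "contigs\tpredictions\tscores". The per-rank
        # scores are parsed and then dropped, yielding a plain Taxonomy.
        refined = vamb.taxonomy.Taxonomy.from_refined_file(
            Path("outdir_taxometer/results_taxometer.tsv"), comp_metadata, False
        )
        return unrefined, refined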