diff --git a/.github/workflows/cli_vamb.yml b/.github/workflows/cli_vamb.yml
index 9018429d..80e07cde 100644
--- a/.github/workflows/cli_vamb.yml
+++ b/.github/workflows/cli_vamb.yml
@@ -28,7 +28,7 @@ jobs:
           cache-dependency-path: '**/pyproject.toml'
       - name: Download fixtures
         run: |
-          wget https://www.dropbox.com/scl/fi/xzc0tro7oe6tfm3igygpj/ci_data.zip\?rlkey\=xuv6b5eoynfryp4fba1kfp5jm\&st\=rjb1xccw\&dl\=0 -O ci_data.zip
+          wget https://www.dropbox.com/scl/fi/10tdf0w0kf70pf46hy8ks/ci_data.zip\?rlkey\=smlcinkesuwiw557zulgbb59l\&st\=hhokiqma\&dl\=0 -O ci_data.zip
           unzip -o ci_data.zip
      - name: Install dependencies
        run: |
@@ -56,9 +56,6 @@ jobs:
           vamb taxometer --outdir outdir_taxometer --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -pe 10 -pt 10
           ls -la outdir_taxometer
           cat outdir_taxometer/log.txt
-          vamb taxometer --outdir outdir_taxometer_pred --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy outdir_taxometer/results_taxometer.tsv -pe 10 -pt 10
-          ls -la outdir_taxometer_pred
-          cat outdir_taxometer/log.txt
      - name: Run k-means reclustering
        run: |
          vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --markers markers_mock.npz --algorithm kmeans --minfasta 200000
diff --git a/vamb/__main__.py b/vamb/__main__.py
index 12c9a3e7..2d2f9d34 100755
--- a/vamb/__main__.py
+++ b/vamb/__main__.py
@@ -373,30 +373,61 @@ def __init__(
         self.refcheck = refcheck


-class TaxonomyPath:
-    @classmethod
-    def from_args(cls, args: argparse.Namespace):
-        if args.taxonomy is None:
-            raise argparse.ArgumentTypeError(
-                "Cannot load taxonomy without specifying --taxonomy"
-            )
-        return cls(typeasserted(args.taxonomy, Path))
+class TaxonomyBase:
+    __slots__ = ["path"]

     def __init__(self, path: Path):
-        self.path = check_existing_file(path)
+        self.path = path

-    def get_tax_path(self) -> Path:
-        return self.path
+
+class RefinedTaxonomy(TaxonomyBase):
+    pass
+
+
+class UnrefinedTaxonomy(TaxonomyBase):
+    pass
+
+
+def get_taxonomy(args: argparse.Namespace) -> Union[RefinedTaxonomy, UnrefinedTaxonomy]:
+    path = args.taxonomy
+    if path is None:
+        raise ValueError(
+            "Cannot load taxonomy without specifying --taxonomy"
+        )
+    with open(check_existing_file(path)) as file:
+        try:
+            header = next(file).rstrip("\r\n")
+        except StopIteration:
+            header = None
+
+    if header is None:
+        raise ValueError(f'Empty taxonomy path at "{path}"')
+    elif header == vamb.taxonomy.TAXONOMY_HEADER:
+        return UnrefinedTaxonomy(path)
+    elif header == vamb.taxonomy.PREDICTED_TAXONOMY_HEADER:
+        return RefinedTaxonomy(path)
+    else:
+        raise ValueError(
+            f'When reading taxonomy file at "{path}", '
+            f"the first line was neither {repr(vamb.taxonomy.TAXONOMY_HEADER)} "
+            f"nor {repr(vamb.taxonomy.PREDICTED_TAXONOMY_HEADER)}"
+        )


 class TaxometerOptions:
     @classmethod
     def from_args(cls, args: argparse.Namespace):
+        tax = get_taxonomy(args)
+        if isinstance(tax, RefinedTaxonomy):
+            raise ValueError(
+                f'Attempted to run Taxometer to refine taxonomy at "{args.taxonomy}", '
+                "but this file appears to already be an output of Taxometer"
+            )
         return cls(
             GeneralOptions.from_args(args),
             CompositionOptions.from_args(args),
             AbundanceOptions.from_args(args),
-            TaxonomyPath.from_args(args),
+            tax,
             BasicTrainingOptions.from_args_taxometer(args),
             typeasserted(args.pred_softmax_threshold, float),
             args.ploss,
@@ -407,7 +438,7 @@ def __init__(
         general: GeneralOptions,
         composition: CompositionOptions,
         abundance: AbundanceOptions,
-        taxonomy: TaxonomyPath,
+        taxonomy: UnrefinedTaxonomy,
         basic: BasicTrainingOptions,
         softmax_threshold: float,
         ploss: Union[
@@ -432,38 +463,6 @@ def __init__(
         self.basic = basic


-class RefinableTaxonomyOptions:
-    __slots__ = ["path_or_tax_options"]
-
-    @classmethod
-    def from_args(cls, args: argparse.Namespace):
-        predict = not typeasserted(args.no_predictor, bool)
-
-        # TaxometerOptions have more options, but only the composition and the abundance
-        # can be omitted, so we only check those here.
-        if predict:
-            if not CompositionOptions.are_args_present(
-                args
-            ) or not AbundanceOptions.are_args_present(args):
-                raise ValueError(
-                    "If `--no_predictor` is not passed, Taxometer is run to refine taxonomy, "
-                    "and this requires composition input and abundance input to be passed in"
-                )
-            return cls(TaxometerOptions.from_args(args))
-        else:
-            return cls(TaxonomyPath.from_args(args))
-
-    def __init__(self, path_or_tax_options: Union[TaxonomyPath, TaxometerOptions]):
-        self.path_or_tax_options = path_or_tax_options
-
-    def get_tax_path(self) -> Path:
-        p = self.path_or_tax_options
-        if isinstance(p, TaxonomyPath):
-            return p.get_tax_path()
-        else:
-            return p.taxonomy.get_tax_path()
-
-
 class MarkerPath:
     def __init__(self, path: Path):
         self.path = check_existing_file(path)
@@ -504,8 +503,50 @@ def __init__(self, clusters: Path):


 class DBScanOptions:
-    def __init__(self, taxonomy_options: RefinableTaxonomyOptions, n_processes: int):
-        self.taxonomy_options = taxonomy_options
+    @classmethod
+    def from_args(cls, args: argparse.Namespace, n_threads: int):
+        tax = get_taxonomy(args)
+        predict = not typeasserted(args.no_predictor, bool)
+        if predict:
+            if isinstance(tax, RefinedTaxonomy):
+                logger.warning(
+                    "Flag --no_predictor not set, but the taxonomy passed in "
+                    "on --taxonomy is already refined. Skipped refinement."
+                )
+                return cls(tax, n_threads)
+            else:
+                if not (
+                    CompositionOptions.are_args_present(args)
+                    and AbundanceOptions.are_args_present(args)
+                ):
+                    raise ValueError(
+                        "Flag --no_predictor is not set, but abundance "
+                        "or composition has not been passed in, so there is no information "
+                        "to refine taxonomy with"
+                    )
+                tax_options = TaxometerOptions(
+                    GeneralOptions.from_args(args),
+                    CompositionOptions.from_args(args),
+                    AbundanceOptions.from_args(args),
+                    tax,
+                    BasicTrainingOptions.from_args_taxometer(args),
+                    typeasserted(args.pred_softmax_threshold, float),
+                    args.ploss,
+                )
+                return cls(tax_options, n_threads)
+        else:
+            return cls(tax, n_threads)
+
+    def __init__(
+        self,
+        ops: Union[
+            UnrefinedTaxonomy,
+            RefinedTaxonomy,
+            TaxometerOptions,
+        ],
+        n_processes: int,
+    ):
+        self.taxonomy = ops
         self.n_processes = n_processes


@@ -714,15 +755,35 @@ def from_args(cls, args: argparse.Namespace):
         common = BinnerCommonOptions.from_args(args)
         basic = BasicTrainingOptions.from_args_vae(args)
         vae = VAEOptions.from_args(basic, args)
-        taxonomy = RefinableTaxonomyOptions.from_args(args)
-        return cls(common, vae, taxonomy)
+        taxonomy = get_taxonomy(args)
+        predict = not typeasserted(args.no_predictor, bool)
+        if predict:
+            if isinstance(taxonomy, RefinedTaxonomy):
+                logger.warning(
+                    "Flag --no_predictor not set, but the taxonomy passed in "
+                    "on --taxonomy is already refined. Skipped refinement."
+                )
+                tax = taxonomy
+            else:
+                tax = TaxometerOptions(
+                    common.general,
+                    common.comp,
+                    common.abundance,
+                    taxonomy,
+                    basic,
+                    typeasserted(args.pred_softmax_threshold, float),
+                    args.ploss,
+                )
+        else:
+            tax = taxonomy
+        return cls(common, vae, tax)

     # The VAEVAE models share the same settings as the VAE model, so we just use VAEOptions
     def __init__(
         self,
         common: BinnerCommonOptions,
         vae: VAEOptions,
-        taxonomy: RefinableTaxonomyOptions,
+        taxonomy: Union[RefinedTaxonomy, TaxometerOptions, UnrefinedTaxonomy],
     ):
         self.common = common
         self.vae = vae
@@ -768,15 +829,15 @@ def from_args(cls, args: argparse.Namespace):
             )
             algorithm = KmeansOptions(clusters)
         elif args.algorithm == "dbscan":
-            tax = RefinableTaxonomyOptions.from_args(args)
-            algorithm = DBScanOptions(tax, general.n_threads)
+            algorithm = DBScanOptions.from_args(args, general.n_threads)
         else:
             assert False  # no more algorithms

+        # Avoid loading composition again if already loaded in DBScanOptions
         if isinstance(algorithm, DBScanOptions) and isinstance(
-            algorithm.taxonomy_options.path_or_tax_options, TaxometerOptions
+            algorithm.taxonomy, TaxometerOptions
         ):
-            comp = algorithm.taxonomy_options.path_or_tax_options.composition
+            comp = algorithm.taxonomy.composition
         else:
             comp = CompositionOptions.from_args(args)

@@ -1456,24 +1517,26 @@ def run_vaevae(opt: BinTaxVambOptions):
         composition.metadata.lengths,
         composition.metadata.identifiers,
     )
-    if isinstance(opt.taxonomy.path_or_tax_options, TaxometerOptions):
-        logger.info("Predicting missing values from taxonomy")
+    if isinstance(opt.taxonomy, TaxometerOptions):
         predicted_contig_taxonomies = predict_taxonomy(
             comp_metadata=composition.metadata,
             abundance=abundance,
             tnfs=tnfs,
             lengths=lengths,
             out_dir=opt.common.general.out_dir,
-            taxonomy_options=opt.taxonomy.path_or_tax_options,
+            taxonomy_options=opt.taxonomy,
             cuda=opt.common.general.cuda,
         )
         contig_taxonomies = predicted_contig_taxonomies.to_taxonomy()
+    elif isinstance(opt.taxonomy, RefinedTaxonomy):
+        logger.info("Loading already-refined taxonomy from file")
+        contig_taxonomies = vamb.taxonomy.Taxonomy.from_refined_file(
+            opt.taxonomy.path, composition.metadata, False
+        )
     else:
-        logger.info("Not predicting the taxonomy")
+        logger.info("Loading unrefined taxonomy from file")
         contig_taxonomies = vamb.taxonomy.Taxonomy.from_file(
-            opt.taxonomy.path_or_tax_options.path,
-            composition.metadata,
-            False,
+            opt.taxonomy.path, composition.metadata, False
         )

     nodes, ind_nodes, table_parent = vamb.taxvamb_encode.make_graph(
@@ -1569,9 +1632,6 @@ def run_vaevae(opt: BinTaxVambOptions):
     )


-# TODO: The whole data flow around predict_taxonomy needs to change.
-# Ideally, we should have a "get taxonomy" function that loads, possibly refines,
-# and then returns the taxonomy object.
 def run_reclustering(opt: ReclusteringOptions):
     composition = calc_tnf(
         opt.composition,
@@ -1587,8 +1647,8 @@

     if isinstance(alg, DBScanOptions):
         # If we should refine or not.
-        if isinstance(alg.taxonomy_options.path_or_tax_options, TaxometerOptions):
-            taxopt = alg.taxonomy_options.path_or_tax_options
+        if isinstance(alg.taxonomy, TaxometerOptions):
+            taxopt = alg.taxonomy
             abundance = calc_abundance(
                 taxopt.abundance,
                 taxopt.general.out_dir,
@@ -1607,11 +1667,16 @@
             )
             taxonomy = predicted_tax.to_taxonomy()
         else:
-            tax_path = alg.taxonomy_options.path_or_tax_options.path
-            logger.info(f'Loading taxonomy from file "{tax_path}"')
-            taxonomy = vamb.taxonomy.Taxonomy.from_file(
-                tax_path, composition.metadata, True
-            )
+            logger.info(f'Loading taxonomy from file "{alg.taxonomy.path}"')
+            if isinstance(alg.taxonomy, UnrefinedTaxonomy):
+                taxonomy = vamb.taxonomy.Taxonomy.from_file(
+                    alg.taxonomy.path, composition.metadata, True
+                )
+            else:
+                taxonomy = vamb.taxonomy.Taxonomy.from_refined_file(
+                    alg.taxonomy.path, composition.metadata, True
+                )
+
         instantiated_alg = vamb.reclustering.DBScanAlgorithm(
             composition.metadata, taxonomy, opt.general.n_threads
         )
diff --git a/vamb/reclustering.py b/vamb/reclustering.py
index 0c4cfcc8..130ed29e 100644
--- a/vamb/reclustering.py
+++ b/vamb/reclustering.py
@@ -236,12 +236,6 @@ def get_completeness_contamination(counts: np.ndarray) -> tuple[float, float]:
     return (completeness, contamination)


-# An arbitrary score of a bin, where higher numbers is better.
-# completeness - 5 * contamination is used by the CheckM group as a heuristic.
-def score_from_comp_cont(comp_cont: tuple[float, float]) -> float:
-    return comp_cont[0] - 5 * comp_cont[1]
-
-
 def recluster_dbscan(
     taxonomy: Taxonomy,
     latent: np.ndarray,
@@ -251,15 +245,31 @@ def recluster_dbscan(
     contiglengths: np.ndarray,
     markers: Markers,
     num_processes: int,
 ) -> list[set[ContigId]]:
     # Since DBScan is computationally expensive, and scales poorly with the number
     # of contigs, we use taxonomy to only cluster within each genus
-    result: list[set[ContigId]] = []
-    for indices in group_indices_by_genus(taxonomy):
-        genus_latent = latent[indices]
-        genus_clusters = dbscan_genus(
-            genus_latent, indices, contiglengths[indices], markers, num_processes
-        )
-        result.extend(genus_clusters)
+    n_worse_in_row = 0
+    genera_indices = group_indices_by_genus(taxonomy)
+    best_score = 0
+    best_bins: list[set[ContigId]] = []
+    for eps in EPS_VALUES:
+        bins: list[set[ContigId]] = []
+        for indices in genera_indices:
+            genus_clusters = dbscan_genus(
+                latent[indices], indices, contiglengths[indices], num_processes, eps
+            )
+            bins.extend(genus_clusters)

-    return result
+        score = count_good_genomes(bins, markers)
+        if best_score == 0 or score > best_score:
+            best_bins = bins
+            best_score = score
+
+        if score >= best_score:
+            n_worse_in_row = 0
+        else:
+            n_worse_in_row += 1
+            if n_worse_in_row > 2:
+                break
+
+    return best_bins


 # DBScan within the subset of contigs that are annotated with a single genus
@@ -267,8 +277,8 @@ def dbscan_genus(
     latent_of_genus: np.ndarray,
     original_indices: np.ndarray,
     contiglengths_of_genus: np.ndarray,
-    markers: Markers,
     num_processes: int,
+    eps: float,
 ) -> list[set[ContigId]]:
     assert len(latent_of_genus) == len(original_indices) == len(contiglengths_of_genus)
     # Precompute distance matrix.
    # This is O(N^2), but DBScan is even worse,
@@ -278,41 +288,21 @@ def dbscan_genus(
     distance_matrix = pairwise_distances(
         latent_of_genus, latent_of_genus, metric="cosine"
     )
-    best_bins: tuple[int, dict[int, set[ContigId]]] = (-1, dict())
     # The DBScan approach works by blindly clustering with different eps values
     # (a critical parameter for DBscan), and then using SCGs to select the best
     # subset of clusters.
     # It's ugly and wasteful, but it does work.
-    n_worse_in_row = 0
-    for eps in EPS_VALUES:
-        print(f"Running eps: {eps}")
-        dbscan = DBSCAN(
-            eps=eps,
-            min_samples=5,
-            n_jobs=num_processes,
-            metric="precomputed",
-        )
-        dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
-        these_bins: dict[int, set[ContigId]] = defaultdict(set)
-        for original_index, bin_index in zip(original_indices, dbscan.labels_):
-            these_bins[bin_index].add(ContigId(original_index))
-        score = count_good_genomes(these_bins.values(), markers)
-        print(f"Score: {score}, best: {best_bins[0]}, worse: {n_worse_in_row}")
-        if score > best_bins[0]:
-            best_bins = (score, these_bins)
-            n_worse_in_row = 0
-        elif score < best_bins[0]:
-            # This is an elif statement and not an if statement because we don't
-            # want e.g. a series of [1,1,1,1,1] genomes to be considered a performance
-            # degradation leading to early exit, when in fact, it probably means
-            # that eps is too low and should be increased.
-            n_worse_in_row += 1
-            # If we see 3 worse clusterings in a row, we exit early.
-
-        if n_worse_in_row > 2:
-            break
-
-    return list(best_bins[1].values())
+    dbscan = DBSCAN(
+        eps=eps,
+        min_samples=5,
+        n_jobs=num_processes,
+        metric="precomputed",
+    )
+    dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
+    bins: dict[int, set[ContigId]] = defaultdict(set)
+    for original_index, bin_index in zip(original_indices, dbscan.labels_):
+        bins[bin_index].add(ContigId(original_index))
+    return list(bins.values())


 def count_good_genomes(binning: Iterable[Iterable[ContigId]], markers: Markers) -> int:
diff --git a/vamb/taxonomy.py b/vamb/taxonomy.py
index 7e88919d..7a9e1d63 100644
--- a/vamb/taxonomy.py
+++ b/vamb/taxonomy.py
@@ -3,6 +3,9 @@
 from vamb.parsecontigs import CompositionMetaData
 import numpy as np

+TAXONOMY_HEADER = "contigs\tpredictions"
+PREDICTED_TAXONOMY_HEADER = "contigs\tpredictions\tscores"
+

 class ContigTaxonomy:
     """
@@ -56,6 +59,14 @@ def from_file(
         observed = cls.parse_tax_file(tax_file, is_canonical)
         return cls.from_observed(observed, metadata, is_canonical)

+    @classmethod
+    def from_refined_file(
+        cls, tax_file: Path, metadata: CompositionMetaData, is_canonical: bool
+    ):
+        observed = PredictedTaxonomy.parse_tax_file(tax_file, is_canonical)
+        observed = [(name, tax.contig_taxonomy) for (name, tax) in observed]
+        return cls.from_observed(observed, metadata, is_canonical)
+
     @classmethod
     def from_observed(
         cls,
@@ -100,13 +111,14 @@ def parse_tax_file(
     ) -> list[tuple[str, ContigTaxonomy]]:
         with open(path) as file:
             result: list[tuple[str, ContigTaxonomy]] = []
-            lines = filter(None, map(str.rstrip, file))
-            header = next(lines, None)
-            if header is None or not header.startswith("contigs\tpredictions"):
+            header = next(file, None)
+            if header is None or not header.startswith(TAXONOMY_HEADER):
                 raise ValueError(
-                    'In taxonomy file, expected header to begin with "contigs\\tpredictions"'
+                    f"In taxonomy file, expected header to begin with {repr(TAXONOMY_HEADER)}"
                 )
-            for line in lines:
+            for line in file:
+                line = line.rstrip("\r\n")
+                if not line:
+                    continue
                 (contigname, taxonomy, *_) = line.split("\t")
                 result.append(
                     (
@@ -124,7 +134,10 @@ class PredictedContigTaxonomy:
     def __init__(self, tax: ContigTaxonomy, probs: np.ndarray):
         if len(probs) != len(tax.ranks):
             raise ValueError("The length of probs must equal that of ranks")
-        self.tax = tax
+        # Due to floating point errors, the probabilities may be slightly outside of [0, 1].
+        # We could perhaps validate the values, but that's not likely to be necessary.
+        np.clip(probs, a_min=0.0, a_max=1.0, out=probs)
+        self.contig_taxonomy = tax
         self.probs = probs


@@ -147,28 +160,56 @@ def __init__(
         self.is_canonical = is_canonical

     def to_taxonomy(self) -> Taxonomy:
-        lst: list[Optional[ContigTaxonomy]] = [p.tax for p in self.contig_taxonomies]
+        lst: list[Optional[ContigTaxonomy]] = [
+            p.contig_taxonomy for p in self.contig_taxonomies
+        ]
         return Taxonomy(lst, self.refhash, self.is_canonical)

     @property
     def nseqs(self) -> int:
         return len(self.contig_taxonomies)

+    @staticmethod
+    def parse_tax_file(
+        path: Path, force_canonical: bool
+    ) -> list[tuple[str, PredictedContigTaxonomy]]:
+        with open(path) as file:
+            result: list[tuple[str, PredictedContigTaxonomy]] = []
+            lines = filter(None, map(str.rstrip, file))
+            header = next(lines, None)
+            if header is None or not header.startswith(PREDICTED_TAXONOMY_HEADER):
+                raise ValueError(
+                    f"In predicted taxonomy file, expected header to begin with {repr(PREDICTED_TAXONOMY_HEADER)}"
+                )
+            for line in lines:
+                (contigname, taxonomy, scores, *_) = line.split("\t")
+                contig_taxonomy = ContigTaxonomy.from_semicolon_sep(
+                    taxonomy, force_canonical
+                )
+                probs = np.array([float(i) for i in scores.split(";")], dtype=float)
+                result.append(
+                    (
+                        contigname,
+                        PredictedContigTaxonomy(contig_taxonomy, probs),
+                    )
+                )
+
+        return result
+
     def write_as_tsv(self, file: IO[str], comp_metadata: CompositionMetaData):
         if self.refhash != comp_metadata.refhash:
             raise ValueError(
                 "Refhash of comp_metadata and predicted taxonomy must match"
             )
         assert self.nseqs == comp_metadata.nseqs
-        print("contigs\tpredictions\tlengths\tscores", file=file)
+        print(PREDICTED_TAXONOMY_HEADER, file=file)
         for i in range(self.nseqs):
             tax = self.contig_taxonomies[i]
-            ranks_str = ";".join(tax.tax.ranks)
+            ranks_str = ";".join(tax.contig_taxonomy.ranks)
             probs_str = ";".join([str(round(i, 5)) for i in tax.probs])
             print(
                 comp_metadata.identifiers[i],
                 ranks_str,
-                comp_metadata.lengths[i],
                 probs_str,
                 file=file,
                 sep="\t",
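
Note on the taxonomy-file dispatch introduced above: get_taxonomy() chooses between the refined and unrefined code paths purely by reading the first line of the taxonomy TSV. A minimal standalone sketch of that idea, outside the vamb CLI plumbing (the two header constants mirror vamb/taxonomy.py; sniff_taxonomy is a hypothetical helper name, not part of vamb's API):

    from pathlib import Path

    TAXONOMY_HEADER = "contigs\tpredictions"
    PREDICTED_TAXONOMY_HEADER = "contigs\tpredictions\tscores"

    def sniff_taxonomy(path: Path) -> str:
        # Read only the first line; the header alone determines the file kind.
        with open(path) as file:
            header = next(file, "").rstrip("\r\n")
        if header == PREDICTED_TAXONOMY_HEADER:
            return "refined"    # already a Taxometer output
        if header == TAXONOMY_HEADER:
            return "unrefined"  # raw taxonomy, eligible for refinement
        raise ValueError(f"Unrecognized taxonomy header: {header!r}")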
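The reclustering.py change hoists the eps scan out of dbscan_genus (per genus) into recluster_dbscan, so each candidate eps is now scored across all genera at once before deciding whether to continue. A compact sketch of that early-stopping scan under stated assumptions (cluster_at and score are stand-ins for the per-genus DBSCAN pass and the marker-gene scoring; the EPS_VALUES shown are illustrative, not vamb's actual grid):

    EPS_VALUES = (0.05, 0.1, 0.15, 0.2, 0.3)  # illustrative values only

    def scan_eps(cluster_at, score):
        best_score, best_bins, n_worse_in_row = 0, [], 0
        for eps in EPS_VALUES:
            bins = cluster_at(eps)  # one full clustering pass over all genera
            s = score(bins)         # e.g. number of good genomes by SCG counts
            if best_score == 0 or s > best_score:
                best_bins, best_score = bins, s
            if s >= best_score:
                n_worse_in_row = 0  # a tie is not counted as a degradation
            else:
                n_worse_in_row += 1
                if n_worse_in_row > 2:
                    break           # three worse scans in a row: stop early
        return best_bins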