From 90e195ec50ebcfd4557496e947cd26cb04962fa8 Mon Sep 17 00:00:00 2001
From: Jakob Nybo Nissen
Date: Tue, 19 Nov 2024 16:40:52 +0100
Subject: [PATCH] WIP: DBScan

---
 src/taxobench.py     | 66 -----------------------------------
 vamb/reclustering.py | 82 +++++++++++++++++++-------------------------
 vamb/taxonomy.py     | 12 ++++---
 3 files changed, 44 insertions(+), 116 deletions(-)
 delete mode 100644 src/taxobench.py

diff --git a/src/taxobench.py b/src/taxobench.py
deleted file mode 100644
index e8a4fbd0..00000000
--- a/src/taxobench.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from math import log
-from pathlib import Path
-
-from vamb.taxonomy import (
-    ContigTaxonomy,
-    PredictedContigTaxonomy,
-    Taxonomy,
-    PredictedTaxonomy,
-)
-
-# The score is computed as the log of the probability assigned to the right species.
-# At any clade, we assume there are e^2+1 children, and all the children not predicted
-# have been given the same score.
-
-# Examples:
-# 1) Correct guess at species level. The predictor predicts the species with score 0.8:
-# Result: log(0.8)
-
-# 2) Correct guess at genus level; wrong at species level with score 0.8:
-# The remaining score of 0.8 is divided by the remaining e^2 children:
-# Result: log(0.2 / e^2) = log(0.2) - 2
-
-# 3) Correct guess at family level; wrong at genus level with score 0.8:
-# The remaining score of 0.2 is divided among e^2 children, each whom have e^2+1 children.
-# Result: log(0.2 / (e^2 * (e^2 + 1))) - we round this off to log(0.2 / (e^2 * e^2)) = log(0.2) - 4
-
-# So: Result is: If correct, log of last score. If N levels are incorrect, it's log(1 - score at first level) - 2N
-
-
-# INVARIANT: Must be canonical
-def pad_tax(x: list):
-    x = x.copy()
-    if len(x) > 6:
-        return x
-    x.extend([None] * (7 - len(x)))
-    x.reverse()
-    return x
-
-
-def score(true: ContigTaxonomy, pred: PredictedContigTaxonomy) -> float:
-    for rank, ((true_tax, pred_tax, prob)) in enumerate(
-        zip(true.ranks, pred.contig_taxonomy.ranks, pred.probs)
-    ):
-        if true_tax != pred_tax:
-            wrong_ranks = 7 - rank
-            return log(1 - prob) - 2 * wrong_ranks
-
-    for n_wrong_minus_one, (truerank, predrank, prob) in enumerate(
-        zip(pad_tax(true.ranks), pad_tax(pred.contig_taxonomy.ranks), pred.probs)
-    ):
-        if truerank != predrank:
-            return log(1 - prob) - 2 * (n_wrong_minus_one + 1)
-    return log(pred.probs[-1])
-
-
-def load_scores(truth_path: Path, pred_path: Path) -> list[tuple[str, int, float]]:
-    truth = dict(Taxonomy.parse_tax_file(truth_path, True))
-    pred = PredictedTaxonomy.parse_tax_file(pred_path, True)
-    return [
-        (name, length, score(truth[name], contig_pred))
-        for (name, length, contig_pred) in pred
-    ]
-
-
-def weighted_score(lst: list[tuple[str, int, float]]) -> float:
-    return sum(i[1] * i[2] for i in lst) / sum(i[1] for i in lst)
diff --git a/vamb/reclustering.py b/vamb/reclustering.py
index 0c4cfcc8..130ed29e 100644
--- a/vamb/reclustering.py
+++ b/vamb/reclustering.py
@@ -236,12 +236,6 @@ def get_completeness_contamination(counts: np.ndarray) -> tuple[float, float]:
     return (completeness, contamination)
 
 
-# An arbitrary score of a bin, where higher numbers is better.
-# completeness - 5 * contamination is used by the CheckM group as a heuristic.
-def score_from_comp_cont(comp_cont: tuple[float, float]) -> float:
-    return comp_cont[0] - 5 * comp_cont[1]
-
-
 def recluster_dbscan(
     taxonomy: Taxonomy,
     latent: np.ndarray,
@@ -251,15 +245,31 @@
 ) -> list[set[ContigId]]:
     # Since DBScan is computationally expensive, and scales poorly with the number
     # of contigs, we use taxonomy to only cluster within each genus
-    result: list[set[ContigId]] = []
-    for indices in group_indices_by_genus(taxonomy):
-        genus_latent = latent[indices]
-        genus_clusters = dbscan_genus(
-            genus_latent, indices, contiglengths[indices], markers, num_processes
-        )
-        result.extend(genus_clusters)
+    n_worse_in_row = 0
+    genera_indices = group_indices_by_genus(taxonomy)
+    best_score = 0
+    best_bins: list[set[ContigId]] = []
+    for eps in EPS_VALUES:
+        bins: list[set[ContigId]] = []
+        for indices in genera_indices:
+            genus_clusters = dbscan_genus(
+                latent[indices], indices, contiglengths[indices], num_processes, eps
+            )
+            bins.extend(genus_clusters)
 
-    return result
+        score = count_good_genomes(bins, markers)
+        if best_score == 0 or score > best_score:
+            best_bins = bins
+            best_score = score
+
+        if score >= best_score:
+            n_worse_in_row = 0
+        else:
+            n_worse_in_row += 1
+            if n_worse_in_row > 2:
+                break
+
+    return best_bins
 
 
 # DBScan within the subset of contigs that are annotated with a single genus
@@ -267,8 +277,8 @@ def dbscan_genus(
     latent_of_genus: np.ndarray,
     original_indices: np.ndarray,
     contiglengths_of_genus: np.ndarray,
-    markers: Markers,
     num_processes: int,
+    eps: float,
 ) -> list[set[ContigId]]:
     assert len(latent_of_genus) == len(original_indices) == len(contiglengths_of_genus)
     # Precompute distance matrix. This is O(N^2), but DBScan is even worse,
@@ -278,41 +288,21 @@
     distance_matrix = pairwise_distances(
         latent_of_genus, latent_of_genus, metric="cosine"
     )
-    best_bins: tuple[int, dict[int, set[ContigId]]] = (-1, dict())
     # The DBScan approach works by blindly clustering with different eps values
     # (a critical parameter for DBscan), and then using SCGs to select the best
     # subset of clusters.
     # It's ugly and wasteful, but it does work.
-    n_worse_in_row = 0
-    for eps in EPS_VALUES:
-        print(f"Running eps: {eps}")
-        dbscan = DBSCAN(
-            eps=eps,
-            min_samples=5,
-            n_jobs=num_processes,
-            metric="precomputed",
-        )
-        dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
-        these_bins: dict[int, set[ContigId]] = defaultdict(set)
-        for original_index, bin_index in zip(original_indices, dbscan.labels_):
-            these_bins[bin_index].add(ContigId(original_index))
-        score = count_good_genomes(these_bins.values(), markers)
-        print(f"Score: {score}, best: {best_bins[0]}, worse: {n_worse_in_row}")
-        if score > best_bins[0]:
-            best_bins = (score, these_bins)
-            n_worse_in_row = 0
-        elif score < best_bins[0]:
-            # This is an elif statement and not an if statement because we don't
-            # want e.g. a series of [1,1,1,1,1] genomes to be considered a performance
-            # degradation leading to early exit, when in fact, it probably means
-            # that eps is too low and should be increased.
-            n_worse_in_row += 1
-        # If we see 3 worse clusterings in a row, we exit early.
-
-        if n_worse_in_row > 2:
-            break
-
-    return list(best_bins[1].values())
+    dbscan = DBSCAN(
+        eps=eps,
+        min_samples=5,
+        n_jobs=num_processes,
+        metric="precomputed",
+    )
+    dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
+    bins: dict[int, set[ContigId]] = defaultdict(set)
+    for original_index, bin_index in zip(original_indices, dbscan.labels_):
+        bins[bin_index].add(ContigId(original_index))
+    return list(bins.values())
 
 
 def count_good_genomes(binning: Iterable[Iterable[ContigId]], markers: Markers) -> int:
diff --git a/vamb/taxonomy.py b/vamb/taxonomy.py
index 2a0b1f14..d7902539 100644
--- a/vamb/taxonomy.py
+++ b/vamb/taxonomy.py
@@ -100,13 +100,12 @@ def parse_tax_file(
     ) -> list[tuple[str, ContigTaxonomy]]:
         with open(path) as file:
             result: list[tuple[str, ContigTaxonomy]] = []
-            lines = filter(None, map(str.rstrip, file))
-            header = next(lines, None)
+            header = next(file, None)
             if header is None or not header.startswith("contigs\tpredictions"):
                 raise ValueError(
                     'In taxonomy file, expected header to begin with "contigs\\tpredictions"'
                 )
-            for line in lines:
+            for line in file:
                 (contigname, taxonomy, *_) = line.split("\t")
                 result.append(
                     (
@@ -124,6 +123,9 @@ class PredictedContigTaxonomy:
     def __init__(self, tax: ContigTaxonomy, probs: np.ndarray):
         if len(probs) != len(tax.ranks):
            raise ValueError("The length of probs must equal that of ranks")
+        # Due to floating point errors, the probabilities may be slightly outside of 0 or 1.
+        # We could perhaps validate the values, but that's not likely to be necessary.
+        np.clip(probs, a_min=0.0, a_max=1.0, out=probs)
         self.contig_taxonomy = tax
         self.probs = probs
 
@@ -158,7 +160,7 @@ def nseqs(self) -> int:
 
     @staticmethod
     def parse_tax_file(
-        path: Path, force_canonical: bool
+        path: Path, minlen: int, force_canonical: bool
     ) -> list[tuple[str, int, PredictedContigTaxonomy]]:
         with open(path) as file:
             result: list[tuple[str, int, PredictedContigTaxonomy]] = []
@@ -173,6 +175,8 @@ def parse_tax_file(
             for line in lines:
                 (contigname, taxonomy, lengthstr, scores, *_) = line.split("\t")
                 length = int(lengthstr)
+                if length < minlen:
+                    continue
                 contig_taxonomy = ContigTaxonomy.from_semicolon_sep(
                     taxonomy, force_canonical
                 )
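Not part of the patch: a minimal standalone sketch of the per-genus DBSCAN call that dbscan_genus performs for a single, fixed eps after this change (the eps sweep and the SCG-based selection now live in recluster_dbscan). The random inputs below are purely illustrative stand-ins for one genus' latent embeddings, contig lengths, and original indices; only the sklearn calls mirror the patched code.

from collections import defaultdict

import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

# Illustrative stand-ins for one genus: latent embeddings, contig lengths,
# and the positions of these contigs in the full contig list.
rng = np.random.default_rng(0)
latent_of_genus = rng.normal(size=(200, 32))
contiglengths_of_genus = rng.integers(2_000, 50_000, size=200)
original_indices = np.arange(200)

# Precompute the cosine distance matrix once and hand it to DBSCAN as
# "precomputed", weighting each contig by its length, as in dbscan_genus.
distance_matrix = pairwise_distances(latent_of_genus, latent_of_genus, metric="cosine")
dbscan = DBSCAN(eps=0.05, min_samples=5, n_jobs=1, metric="precomputed")
dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)

# Map each DBSCAN label back to the original contig indices; label -1 is noise.
bins: dict[int, set[int]] = defaultdict(set)
for original_index, label in zip(original_indices, dbscan.labels_):
    bins[label].add(int(original_index))
print(len(bins), "bins (including the noise bin, if any)")

In the patched recluster_dbscan, the caller repeats this for each eps in EPS_VALUES, scores each candidate binning with count_good_genomes, keeps the best, and stops after three consecutive eps values that do not improve on it.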