WIP: DBScan
jakobnissen committed Nov 19, 2024
1 parent 7850525 commit 90e195e
Showing 3 changed files with 44 additions and 116 deletions.
66 changes: 0 additions & 66 deletions src/taxobench.py

This file was deleted.

82 changes: 36 additions & 46 deletions vamb/reclustering.py
@@ -236,12 +236,6 @@ def get_completeness_contamination(counts: np.ndarray) -> tuple[float, float]:
    return (completeness, contamination)


# An arbitrary score of a bin, where a higher number is better.
# completeness - 5 * contamination is used by the CheckM group as a heuristic.
def score_from_comp_cont(comp_cont: tuple[float, float]) -> float:
    return comp_cont[0] - 5 * comp_cont[1]


def recluster_dbscan(
    taxonomy: Taxonomy,
    latent: np.ndarray,
@@ -251,24 +245,40 @@
) -> list[set[ContigId]]:
    # Since DBSCAN is computationally expensive and scales poorly with the number
    # of contigs, we use the taxonomy to cluster only within each genus.
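    # For each candidate eps we cluster every genus, pool the resulting bins, and
    # score the pooled binning by its number of good genomes (judged from the
    # marker genes); the best-scoring binning over the swept eps values is returned.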
    result: list[set[ContigId]] = []
    for indices in group_indices_by_genus(taxonomy):
        genus_latent = latent[indices]
        genus_clusters = dbscan_genus(
            genus_latent, indices, contiglengths[indices], markers, num_processes
        )
        result.extend(genus_clusters)
    n_worse_in_row = 0
    genera_indices = group_indices_by_genus(taxonomy)
    best_score = 0
    best_bins: list[set[ContigId]] = []
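    # eps is a critical DBSCAN parameter, so several candidate values are tried.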
    for eps in EPS_VALUES:
        bins: list[set[ContigId]] = []
        for indices in genera_indices:
            genus_clusters = dbscan_genus(
                latent[indices], indices, contiglengths[indices], num_processes, eps
            )
            bins.extend(genus_clusters)

    return result
        score = count_good_genomes(bins, markers)
        if best_score == 0 or score > best_score:
            best_bins = bins
            best_score = score

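        # An equal score does not count as a degradation: a run of identical scores
        # more likely means eps is still too low and should keep increasing.
        # We stop the sweep after three consecutive worse clusterings.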
        if score >= best_score:
            n_worse_in_row = 0
        else:
            n_worse_in_row += 1
            if n_worse_in_row > 2:
                break

    return best_bins


# DBScan within the subset of contigs that are annotated with a single genus
def dbscan_genus(
    latent_of_genus: np.ndarray,
    original_indices: np.ndarray,
    contiglengths_of_genus: np.ndarray,
    markers: Markers,
    num_processes: int,
    eps: float,
) -> list[set[ContigId]]:
    assert len(latent_of_genus) == len(original_indices) == len(contiglengths_of_genus)
    # Precompute distance matrix. This is O(N^2), but DBScan is even worse,
@@ -278,41 +288,21 @@ def dbscan_genus(
    distance_matrix = pairwise_distances(
        latent_of_genus, latent_of_genus, metric="cosine"
    )
    best_bins: tuple[int, dict[int, set[ContigId]]] = (-1, dict())
    # The DBScan approach works by blindly clustering with different eps values
    # (a critical parameter for DBscan), and then using SCGs to select the best
    # subset of clusters.
    # It's ugly and wasteful, but it does work.
    n_worse_in_row = 0
    for eps in EPS_VALUES:
        print(f"Running eps: {eps}")
        dbscan = DBSCAN(
            eps=eps,
            min_samples=5,
            n_jobs=num_processes,
            metric="precomputed",
        )
        dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
        these_bins: dict[int, set[ContigId]] = defaultdict(set)
        for original_index, bin_index in zip(original_indices, dbscan.labels_):
            these_bins[bin_index].add(ContigId(original_index))
        score = count_good_genomes(these_bins.values(), markers)
        print(f"Score: {score}, best: {best_bins[0]}, worse: {n_worse_in_row}")
        if score > best_bins[0]:
            best_bins = (score, these_bins)
            n_worse_in_row = 0
        elif score < best_bins[0]:
            # This is an elif statement and not an if statement because we don't
            # want e.g. a series of [1,1,1,1,1] genomes to be considered a performance
            # degradation leading to early exit, when in fact, it probably means
            # that eps is too low and should be increased.
            n_worse_in_row += 1
        # If we see 3 worse clusterings in a row, we exit early.

        if n_worse_in_row > 2:
            break

    return list(best_bins[1].values())
    dbscan = DBSCAN(
        eps=eps,
        min_samples=5,
        n_jobs=num_processes,
        metric="precomputed",
    )
    dbscan.fit(distance_matrix, sample_weight=contiglengths_of_genus)
    bins: dict[int, set[ContigId]] = defaultdict(set)
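    # Note: DBSCAN labels noise points -1, so all unclustered contigs end up
    # together in a single set keyed by -1.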
    for original_index, bin_index in zip(original_indices, dbscan.labels_):
        bins[bin_index].add(ContigId(original_index))
    return list(bins.values())


def count_good_genomes(binning: Iterable[Iterable[ContigId]], markers: Markers) -> int:
12 changes: 8 additions & 4 deletions vamb/taxonomy.py
@@ -100,13 +100,12 @@ def parse_tax_file(
    ) -> list[tuple[str, ContigTaxonomy]]:
        with open(path) as file:
            result: list[tuple[str, ContigTaxonomy]] = []
            lines = filter(None, map(str.rstrip, file))
            header = next(lines, None)
            header = next(file, None)
            if header is None or not header.startswith("contigs\tpredictions"):
                raise ValueError(
                    'In taxonomy file, expected header to begin with "contigs\\tpredictions"'
                )
            for line in lines:
            for line in file:
                (contigname, taxonomy, *_) = line.split("\t")
                result.append(
                    (
@@ -124,6 +123,9 @@ class PredictedContigTaxonomy:
    def __init__(self, tax: ContigTaxonomy, probs: np.ndarray):
        if len(probs) != len(tax.ranks):
            raise ValueError("The length of probs must equal that of ranks")
        # Due to floating point errors, the probabilities may fall slightly outside [0, 1].
        # We could perhaps validate the values, but that's not likely to be necessary.
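        # Clipping with out=probs writes the result back into the caller's array in place.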
        np.clip(probs, a_min=0.0, a_max=1.0, out=probs)
        self.contig_taxonomy = tax
        self.probs = probs

@@ -158,7 +160,7 @@ def nseqs(self) -> int:

    @staticmethod
    def parse_tax_file(
        path: Path, force_canonical: bool
        path: Path, minlen: int, force_canonical: bool
    ) -> list[tuple[str, int, PredictedContigTaxonomy]]:
        with open(path) as file:
            result: list[tuple[str, int, PredictedContigTaxonomy]] = []
@@ -173,6 +175,8 @@ def parse_tax_file(
            for line in lines:
                (contigname, taxonomy, lengthstr, scores, *_) = line.split("\t")
                length = int(lengthstr)
                if length < minlen:
                    continue
                contig_taxonomy = ContigTaxonomy.from_semicolon_sep(
                    taxonomy, force_canonical
                )
