Skip to content

Commit

Permalink
WIP: Fixup tax
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobnissen committed Nov 19, 2024
1 parent 003735d commit 7850525
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 3 deletions.
66 changes: 66 additions & 0 deletions src/taxobench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from math import log
from pathlib import Path

from vamb.taxonomy import (
ContigTaxonomy,
PredictedContigTaxonomy,
Taxonomy,
PredictedTaxonomy,
)

# The score is computed as the log of the probability assigned to the right species.
# At any clade, we assume there are e^2+1 children, and all the children not predicted
# have been given the same score.

# Examples:
# 1) Correct guess at species level. The predictor predicts the species with score 0.8:
# Result: log(0.8)

# 2) Correct guess at genus level; wrong at species level with score 0.8:
# The remaining score of 0.8 is divided by the remaining e^2 children:
# Result: log(0.2 / e^2) = log(0.2) - 2

# 3) Correct guess at family level; wrong at genus level with score 0.8:
# The remaining score of 0.2 is divided among e^2 children, each whom have e^2+1 children.
# Result: log(0.2 / (e^2 * (e^2 + 1))) - we round this off to log(0.2 / (e^2 * e^2)) = log(0.2) - 4

# So: Result is: If correct, log of last score. If N levels are incorrect, it's log(1 - score at first level) - 2N


# INVARIANT: Must be canonical
def pad_tax(x: list):
x = x.copy()
if len(x) > 6:
return x
x.extend([None] * (7 - len(x)))
x.reverse()
return x


def score(true: ContigTaxonomy, pred: PredictedContigTaxonomy) -> float:
for rank, ((true_tax, pred_tax, prob)) in enumerate(
zip(true.ranks, pred.contig_taxonomy.ranks, pred.probs)
):
if true_tax != pred_tax:
wrong_ranks = 7 - rank
return log(1 - prob) - 2 * wrong_ranks

for n_wrong_minus_one, (truerank, predrank, prob) in enumerate(
zip(pad_tax(true.ranks), pad_tax(pred.contig_taxonomy.ranks), pred.probs)
):
if truerank != predrank:
return log(1 - prob) - 2 * (n_wrong_minus_one + 1)
return log(pred.probs[-1])


def load_scores(truth_path: Path, pred_path: Path) -> list[tuple[str, int, float]]:
truth = dict(Taxonomy.parse_tax_file(truth_path, True))
pred = PredictedTaxonomy.parse_tax_file(pred_path, True)
return [
(name, length, score(truth[name], contig_pred))
for (name, length, contig_pred) in pred
]


def weighted_score(lst: list[tuple[str, int, float]]) -> float:
return sum(i[1] * i[2] for i in lst) / sum(i[1] for i in lst)
39 changes: 36 additions & 3 deletions vamb/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class PredictedContigTaxonomy:
def __init__(self, tax: ContigTaxonomy, probs: np.ndarray):
if len(probs) != len(tax.ranks):
raise ValueError("The length of probs must equal that of ranks")
self.tax = tax
self.contig_taxonomy = tax
self.probs = probs


Expand All @@ -147,13 +147,46 @@ def __init__(
self.is_canonical = is_canonical

def to_taxonomy(self) -> Taxonomy:
lst: list[Optional[ContigTaxonomy]] = [p.tax for p in self.contig_taxonomies]
lst: list[Optional[ContigTaxonomy]] = [
p.contig_taxonomy for p in self.contig_taxonomies
]
return Taxonomy(lst, self.refhash, self.is_canonical)

@property
def nseqs(self) -> int:
return len(self.contig_taxonomies)

@staticmethod
def parse_tax_file(
path: Path, force_canonical: bool
) -> list[tuple[str, int, PredictedContigTaxonomy]]:
with open(path) as file:
result: list[tuple[str, int, PredictedContigTaxonomy]] = []
lines = filter(None, map(str.rstrip, file))
header = next(lines, None)
if header is None or not header.startswith(
"contigs\tpredictions\tlengths\tscores"
):
raise ValueError(
'In predicted taxonomy file, expected header to begin with "contigs\\tpredictions\\tlengths\\tscores"'
)
for line in lines:
(contigname, taxonomy, lengthstr, scores, *_) = line.split("\t")
length = int(lengthstr)
contig_taxonomy = ContigTaxonomy.from_semicolon_sep(
taxonomy, force_canonical
)
probs = np.array([float(i) for i in scores.split(";")], dtype=float)
result.append(
(
contigname,
length,
PredictedContigTaxonomy(contig_taxonomy, probs),
)
)

return result

def write_as_tsv(self, file: IO[str], comp_metadata: CompositionMetaData):
if self.refhash != comp_metadata.refhash:
raise ValueError(
Expand All @@ -163,7 +196,7 @@ def write_as_tsv(self, file: IO[str], comp_metadata: CompositionMetaData):
print("contigs\tpredictions\tlengths\tscores", file=file)
for i in range(self.nseqs):
tax = self.contig_taxonomies[i]
ranks_str = ";".join(tax.tax.ranks)
ranks_str = ";".join(tax.contig_taxonomy.ranks)
probs_str = ";".join([str(round(i, 5)) for i in tax.probs])
print(
comp_metadata.identifiers[i],
Expand Down

0 comments on commit 7850525

Please sign in to comment.