fix masking in macro-f1-score
sfluegel committed Jan 4, 2024
1 parent 5fc02d1 commit 7bbe5c1
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions chebai/callbacks/epoch_metrics.py
@@ -29,16 +29,21 @@ def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
         self.threshold = threshold
 
     def update(self, preds: torch.Tensor, labels: torch.Tensor):
-        tps = torch.sum(torch.logical_and(preds > self.threshold, labels), dim=0)
+        tps = torch.sum(
+            torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
+        )
         self.true_positives += tps
         self.positive_predictions += torch.sum(preds > self.threshold, dim=0)
         self.positive_labels += torch.sum(labels, dim=0)
 
     def compute(self):
-        mask = torch.logical_and(self.positive_predictions, self.positive_labels)
+        # ignore classes without positive labels
+        # classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0),
+        # which is propagated to the classwise_f1 and then turned into 0
+        mask = self.positive_labels != 0
         precision = self.true_positives[mask] / self.positive_predictions[mask]
         recall = self.true_positives[mask] / self.positive_labels[mask]
         classwise_f1 = 2 * precision * recall / (precision + recall)
-        # if precision and recall are 0, set f1 to 0 as well
-        # if (precision and recall are 0) or (precision is nan), set f1 to 0
+        # if (precision and recall are 0) or (precision is nan), set f1 to 0
         classwise_f1 = classwise_f1.nan_to_num()
         return torch.mean(classwise_f1)
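
The effect of the new mask is easiest to see on a tiny example. The sketch below is not part of the commit and uses made-up class counts purely for illustration. It compares the old mask (which requires both positive predictions and positive labels, silently dropping classes the model never predicts) with the new one (which only requires positive labels, so a class that is never predicted contributes an F1 of 0 via nan_to_num):

    import torch

    # three hypothetical classes:
    #   class 0: no positive labels               -> ignored by both masks
    #   class 1: positive labels, never predicted -> only the new mask keeps it
    #   class 2: predicted perfectly
    true_positives = torch.tensor([0.0, 0.0, 2.0])
    positive_predictions = torch.tensor([0.0, 0.0, 2.0])
    positive_labels = torch.tensor([0.0, 3.0, 2.0])

    old_mask = torch.logical_and(positive_predictions, positive_labels)  # [False, False, True]
    new_mask = positive_labels != 0                                      # [False, True,  True]

    for name, mask in [("old", old_mask), ("new", new_mask)]:
        precision = true_positives[mask] / positive_predictions[mask]  # 0/0 -> nan for class 1
        recall = true_positives[mask] / positive_labels[mask]
        classwise_f1 = 2 * precision * recall / (precision + recall)
        print(name, torch.mean(classwise_f1.nan_to_num()).item())
    # old -> 1.0 (the missed class is excluded, inflating the score)
    # new -> 0.5 (the missed class counts as F1 = 0)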
