diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 0277d1f5..4df32579 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -481,6 +481,9 @@ def training_step( else: loss = self.criterion(logits, targets, self.current_epoch) + # compensate for mean on the estimators + loss *= self.num_estimators + self.log("train_loss", loss) return loss