diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index 6325e7b7c..dcb9d8764 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -31,6 +31,7 @@ from scipy.optimize import linear_sum_assignment from torch.utils.data import DataLoader, Subset, random_split from lightning.pytorch.callbacks import EarlyStopping +from torchmetrics.aggregation import MeanMetric from torchmetrics.classification import (MultilabelAccuracy, MultilabelJaccardIndex) @@ -278,10 +279,10 @@ def validation_step(self, batch, batch_idx): costs = np.sort(costs)[:len(y_curves[line_cls][0])] penalty = np.full(diff, 8.0) costs = np.concatenate([costs, penalty]) - self.val_line_dist.update(costs/8.0) + self.val_line_mean_dist.update(costs/8.0) # no line output else: - self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0]))) + self.val_line_mean_dist.update(torch.ones(len(y_curves[line_cls][0]))) def on_validation_epoch_end(self): if not self.trainer.sanity_checking: @@ -289,7 +290,7 @@ def on_validation_epoch_end(self): mean_accuracy = self.val_region_mean_accuracy.compute() mean_iu = self.val_region_mean_iu.compute() freq_iu = self.val_region_freq_iu.compute() - mean_line_dist = self.val_line_dist.compute() + line_mean_dist = self.val_line_mean_dist.compute() if mean_iu > self.best_metric: logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})') @@ -302,7 +303,7 @@ def on_validation_epoch_end(self): self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_mean_line_dist', mean_line_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_line_mean_dist', line_mean_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True) # reset metrics even if sanity checking @@ -432,7 +433,8 @@ def setup(self, stage: Optional[str] = None): torch.set_num_threads(max(self.num_workers, 1)) # set up validation metrics after output classes have been determined - # baseline metrics + # baseline metric + self.val_line_mean_dist = MeanMetric() # region metrics num_regions = len(self.val_set.dataset.class_mapping['regions']) self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions)