Skip to content

Commit

Permalink
Create metric on model
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 29, 2024
1 parent fee7f13 commit a97cf34
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader, Subset, random_split
from lightning.pytorch.callbacks import EarlyStopping
from torchmetrics.aggregation import MeanMetric
from torchmetrics.classification import (MultilabelAccuracy,
MultilabelJaccardIndex)

Expand Down Expand Up @@ -278,18 +279,18 @@ def validation_step(self, batch, batch_idx):
costs = np.sort(costs)[:len(y_curves[line_cls][0])]
penalty = np.full(diff, 8.0)
costs = np.concatenate([costs, penalty])
self.val_line_dist.update(costs/8.0)
self.val_line_mean_dist.update(costs/8.0)
# no line output
else:
self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))
self.val_line_mean_dist.update(torch.ones(len(y_curves[line_cls][0])))

def on_validation_epoch_end(self):
if not self.trainer.sanity_checking:
pixel_accuracy = self.val_region_px_accuracy.compute()
mean_accuracy = self.val_region_mean_accuracy.compute()
mean_iu = self.val_region_mean_iu.compute()
freq_iu = self.val_region_freq_iu.compute()
mean_line_dist = self.val_line_dist.compute()
line_mean_dist = self.val_line_mean_dist.compute()

if mean_iu > self.best_metric:
logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
Expand All @@ -302,7 +303,7 @@ def on_validation_epoch_end(self):
self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_mean_line_dist', mean_line_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_line_mean_dist', line_mean_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)

# reset metrics even if sanity checking
Expand Down Expand Up @@ -432,7 +433,8 @@ def setup(self, stage: Optional[str] = None):
torch.set_num_threads(max(self.num_workers, 1))

# set up validation metrics after output classes have been determined
# baseline metrics
# baseline metric
self.val_line_mean_dist = MeanMetric()
# region metrics
num_regions = len(self.val_set.dataset.class_mapping['regions'])
self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions)
Expand Down

0 comments on commit a97cf34

Please sign in to comment.