From 2fb846a45189a490ee0b07d7170f9c03038c5409 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Thu, 9 May 2024 17:10:16 +0200
Subject: [PATCH] don't update metrics after sanity checking

---
 kraken/lib/pretrain/model.py | 17 +++++-----
 kraken/lib/ro/model.py       | 24 +++++++-------
 kraken/lib/train.py          | 63 +++++++++++++++++++-----------------
 3 files changed, 55 insertions(+), 49 deletions(-)

diff --git a/kraken/lib/pretrain/model.py b/kraken/lib/pretrain/model.py
index eb0d3cc57..97b220751 100644
--- a/kraken/lib/pretrain/model.py
+++ b/kraken/lib/pretrain/model.py
@@ -394,16 +394,17 @@ def validation_step(self, batch, batch_idx):
         self.log('CE', loss, on_step=True, on_epoch=True)
 
     def on_validation_epoch_end(self):
-        ce = np.mean(self.val_ce)
+        if not self.trainer.sanity_checking:
+            ce = np.mean(self.val_ce)
+
+            if ce < self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {ce} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = ce
+            logger.info(f'validation run: cross_entropy: {ce}')
+            self.log('val_ce', ce, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.val_ce.clear()
 
-        if ce < self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {ce} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = ce
-        logger.info(f'validation run: cross_enctropy: {ce}')
-        self.log('val_ce', ce, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-
     def training_step(self, batch, batch_idx):
         o = self._step(batch, batch_idx)
         if o is not None:
diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py
index c9c661afa..d5ef0927c 100644
--- a/kraken/lib/ro/model.py
+++ b/kraken/lib/ro/model.py
@@ -164,20 +164,22 @@ def validation_step(self, batch, batch_idx):
         self.val_spearman.append(spearman_dist.cpu())
 
     def on_validation_epoch_end(self):
-        val_metric = np.mean(self.val_spearman)
-        val_loss = np.mean(self.val_losses)
+        if not self.trainer.sanity_checking:
+            val_metric = np.mean(self.val_spearman)
+            val_loss = np.mean(self.val_losses)
+
+            if val_metric < self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = val_metric
+            logger.info(f'validation run: val_spearman {val_metric} val_loss {val_loss}')
+            self.log('val_spearman', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+            self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+
         self.val_spearman.clear()
         self.val_losses.clear()
 
-        if val_metric < self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = val_metric
-        logger.info(f'validation run: val_spearman {val_metric} val_loss {val_loss}')
-        self.log('val_spearman', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-
     def training_step(self, batch, batch_idx):
         x, y = batch['sample'], batch['target']
         logits = self.ro_net(x)
diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index 48e330afd..1e5a41e80 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -515,17 +515,19 @@ def validation_step(self, batch, batch_idx):
                          self.global_step)
 
     def on_validation_epoch_end(self):
-        accuracy = 1.0 - self.val_cer.compute()
-        word_accuracy = 1.0 - self.val_wer.compute()
-
-        if accuracy > self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {accuracy} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = accuracy
-        logger.info(f'validation run: total chars {self.val_cer.total} errors {self.val_cer.errors} accuracy {accuracy}')
-        self.log('val_accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_word_accuracy', word_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_metric', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        if not self.trainer.sanity_checking:
+            accuracy = 1.0 - self.val_cer.compute()
+            word_accuracy = 1.0 - self.val_wer.compute()
+
+            if accuracy > self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {accuracy} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = accuracy
+            logger.info(f'validation run: total chars {self.val_cer.total} errors {self.val_cer.errors} accuracy {accuracy}')
+            self.log('val_accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_word_accuracy', word_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_metric', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        # reset metrics even if sanity checking
         self.val_cer.reset()
         self.val_wer.reset()
 
@@ -905,25 +907,26 @@ def validation_step(self, batch, batch_idx):
         self.val_freq_iu.update(pred, y)
 
     def on_validation_epoch_end(self):
-
-        pixel_accuracy = self.val_px_accuracy.compute()
-        mean_accuracy = self.val_mean_accuracy.compute()
-        mean_iu = self.val_mean_iu.compute()
-        freq_iu = self.val_freq_iu.compute()
-
-        if mean_iu > self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = mean_iu
-
-        logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}')
-
-        self.log('val_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-
+        if not self.trainer.sanity_checking:
+            pixel_accuracy = self.val_px_accuracy.compute()
+            mean_accuracy = self.val_mean_accuracy.compute()
+            mean_iu = self.val_mean_iu.compute()
+            freq_iu = self.val_freq_iu.compute()
+
+            if mean_iu > self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = mean_iu
+
+            logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}')
+
+            self.log('val_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+
+        # reset metrics even if sanity checking
         self.val_px_accuracy.reset()
         self.val_mean_accuracy.reset()
         self.val_mean_iu.reset()
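
For reference, all four hooks above apply the same guard. The sketch below shows the pattern against a generic PyTorch Lightning module; `GuardedModule` and its single `MeanMetric` accumulator are illustrative stand-ins, not kraken code. Lightning sets `trainer.sanity_checking` to True during the short validation pass it runs before training, so best-metric bookkeeping and logging are skipped there, while the accumulator is still reset so sanity-check batches cannot leak into the first real validation epoch.

    import pytorch_lightning as pl
    from torchmetrics import MeanMetric


    class GuardedModule(pl.LightningModule):
        """Illustrative stand-in for the recognition/segmentation modules."""

        def __init__(self):
            super().__init__()
            self.val_metric = MeanMetric()  # stand-in for the real accumulators
            self.best_metric = 0.0
            self.best_epoch = -1

        def on_validation_epoch_end(self):
            # trainer.sanity_checking is True during the pre-training sanity
            # validation pass (controlled by the num_sanity_val_steps Trainer arg).
            if not self.trainer.sanity_checking:
                value = self.val_metric.compute()
                if value > self.best_metric:
                    self.best_epoch = self.current_epoch
                    self.best_metric = value
                self.log('val_metric', value, on_step=False, on_epoch=True)
            # reset unconditionally so sanity-check batches do not leak into
            # the first real validation epoch
            self.val_metric.reset()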