don't update metrics after sanity checking
mittagessen committed May 9, 2024
1 parent ac5e7b5 commit 2fb846a
Showing 3 changed files with 55 additions and 49 deletions.
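All three files apply the same fix to their LightningModule subclasses: on_validation_epoch_end now skips metric aggregation, best-metric bookkeeping, and logging while the trainer is running its pre-fit sanity check, but still clears the per-step accumulators so sanity-check batches cannot leak into the first real validation epoch. A minimal sketch of the pattern, assuming standard PyTorch Lightning hooks (trainer.sanity_checking, self.log, and current_epoch are real Lightning APIs; _step and the accumulator are illustrative):

    import numpy as np
    import pytorch_lightning as pl

    class GuardedModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.val_losses = []             # per-step accumulator (illustrative)
            self.best_metric = float('inf')
            self.best_epoch = -1

        def validation_step(self, batch, batch_idx):
            loss = self._step(batch)         # hypothetical loss computation
            self.val_losses.append(loss.item())

        def on_validation_epoch_end(self):
            # trainer.sanity_checking is True only during the pre-fit
            # validation pass Lightning runs before epoch 0.
            if not self.trainer.sanity_checking:
                loss = np.mean(self.val_losses)
                if loss < self.best_metric:
                    self.best_metric = loss
                    self.best_epoch = self.current_epoch
                self.log('val_loss', loss, on_epoch=True, prog_bar=True)
            # clear unconditionally: the sanity check also ran validation_step
            self.val_losses.clear()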
17 changes: 9 additions & 8 deletions kraken/lib/pretrain/model.py
@@ -394,16 +394,17 @@ def validation_step(self, batch, batch_idx):
         self.log('CE', loss, on_step=True, on_epoch=True)
 
     def on_validation_epoch_end(self):
-        ce = np.mean(self.val_ce)
-
-        if ce < self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {ce} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = ce
-        logger.info(f'validation run: cross_enctropy: {ce}')
-        self.log('val_ce', ce, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        if not self.trainer.sanity_checking:
+            ce = np.mean(self.val_ce)
+
+            if ce < self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {ce} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = ce
+            logger.info(f'validation run: cross_enctropy: {ce}')
+            self.log('val_ce', ce, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.val_ce.clear()
 
     def training_step(self, batch, batch_idx):
         o = self._step(batch, batch_idx)
         if o is not None:
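Note that self.val_ce.clear() deliberately stays outside the new guard: the sanity-check pass still executes validation_step and appends to val_ce, and without the unconditional clear those values would be averaged into the first real epoch's cross-entropy. A self-contained illustration of the leak (the numbers are made up):

    import numpy as np

    val_ce = []
    val_ce.extend([9.1, 9.3])        # sanity check: untrained model, high CE
    # without val_ce.clear() after the sanity pass...
    val_ce.extend([2.0, 2.2])        # epoch 0: real validation batches
    print(np.mean(val_ce))           # 5.65: epoch 0's CE is badly inflated
    # with the unconditional clear, epoch 0 reports np.mean([2.0, 2.2]) == 2.1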
24 changes: 13 additions & 11 deletions kraken/lib/ro/model.py
@@ -164,20 +164,22 @@ def validation_step(self, batch, batch_idx):
         self.val_spearman.append(spearman_dist.cpu())
 
     def on_validation_epoch_end(self):
-        val_metric = np.mean(self.val_spearman)
-        val_loss = np.mean(self.val_losses)
-
-        if val_metric < self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = val_metric
-        logger.info(f'validation run: val_spearman {val_metric} val_loss {val_loss}')
-        self.log('val_spearman', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        if not self.trainer.sanity_checking:
+            val_metric = np.mean(self.val_spearman)
+            val_loss = np.mean(self.val_losses)
+
+            if val_metric < self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = val_metric
+            logger.info(f'validation run: val_spearman {val_metric} val_loss {val_loss}')
+            self.log('val_spearman', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+            self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+
         self.val_spearman.clear()
         self.val_losses.clear()
 
     def training_step(self, batch, batch_idx):
         x, y = batch['sample'], batch['target']
         logits = self.ro_net(x)
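The guard also protects the best-metric bookkeeping: for this model a lower Spearman distance is better (val_metric < self.best_metric), so a couple of lucky sanity-check batches could previously claim best_metric and best_epoch before epoch 0 even started. A toy reproduction of the old failure mode (values invented, guard logic mirrored from the diff):

    import numpy as np

    best_metric, best_epoch = np.inf, -1

    def epoch_end(values, epoch, sanity_checking, guarded=True):
        """Mimics the epoch-end bookkeeping above; illustrative only."""
        global best_metric, best_epoch
        if guarded and sanity_checking:
            return                          # new behaviour: skip bookkeeping
        m = np.mean(values)
        if m < best_metric:
            best_metric, best_epoch = m, epoch

    epoch_end([0.12, 0.15], epoch=0, sanity_checking=True)   # now ignored
    epoch_end([0.20, 0.22], epoch=0, sanity_checking=False)
    print(best_metric, best_epoch)          # 0.21 0, not the sanity values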
63 changes: 33 additions & 30 deletions kraken/lib/train.py
@@ -515,17 +515,19 @@ def validation_step(self, batch, batch_idx):
                                              self.global_step)
 
     def on_validation_epoch_end(self):
-        accuracy = 1.0 - self.val_cer.compute()
-        word_accuracy = 1.0 - self.val_wer.compute()
-
-        if accuracy > self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {accuracy} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = accuracy
-        logger.info(f'validation run: total chars {self.val_cer.total} errors {self.val_cer.errors} accuracy {accuracy}')
-        self.log('val_accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_word_accuracy', word_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_metric', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        if not self.trainer.sanity_checking:
+            accuracy = 1.0 - self.val_cer.compute()
+            word_accuracy = 1.0 - self.val_wer.compute()
+
+            if accuracy > self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {accuracy} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = accuracy
+            logger.info(f'validation run: total chars {self.val_cer.total} errors {self.val_cer.errors} accuracy {accuracy}')
+            self.log('val_accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_word_accuracy', word_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_metric', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        # reset metrics even if not sanity checking
         self.val_cer.reset()
         self.val_wer.reset()
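Unlike the plain lists in the other two files, train.py's accumulators follow the torchmetrics compute()/reset() protocol: they keep internal state across update() calls until reset() is called, which is why the reset lines sit after, not inside, the guard. A sketch of the same interaction outside Lightning (CharErrorRate is a real torchmetrics class, though whether it is the exact one kraken instantiates is assumed here; values invented):

    from torchmetrics.text import CharErrorRate

    val_cer = CharErrorRate()
    val_cer.update(["hallo wrld"], ["hello world"])   # sanity-check batch

    sanity_checking = True
    if not sanity_checking:
        accuracy = 1.0 - val_cer.compute()   # skipped during the sanity pass
    # reset even though compute() was skipped; otherwise the update above
    # would be folded into epoch 0's character error rate
    val_cer.reset()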

@@ -905,25 +907,26 @@ def validation_step(self, batch, batch_idx):
         self.val_freq_iu.update(pred, y)
 
     def on_validation_epoch_end(self):
-
-        pixel_accuracy = self.val_px_accuracy.compute()
-        mean_accuracy = self.val_mean_accuracy.compute()
-        mean_iu = self.val_mean_iu.compute()
-        freq_iu = self.val_freq_iu.compute()
-
-        if mean_iu > self.best_metric:
-            logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
-            self.best_epoch = self.current_epoch
-            self.best_metric = mean_iu
-
-        logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}')
-
-        self.log('val_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        if not self.trainer.sanity_checking:
+            pixel_accuracy = self.val_px_accuracy.compute()
+            mean_accuracy = self.val_mean_accuracy.compute()
+            mean_iu = self.val_mean_iu.compute()
+            freq_iu = self.val_freq_iu.compute()
+
+            if mean_iu > self.best_metric:
+                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
+                self.best_epoch = self.current_epoch
+                self.best_metric = mean_iu
+
+            logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}')
+
+            self.log('val_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
 
+        # reset metrics even if sanity checking
         self.val_px_accuracy.reset()
         self.val_mean_accuracy.reset()
         self.val_mean_iu.reset()
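The segmentation model gets the identical treatment for its four torchmetrics accumulators. For completeness, the pre-fit pass can also be disabled wholesale via a standard Trainer argument, though the in-module guard is the more robust choice because it holds under any trainer configuration:

    from pytorch_lightning import Trainer

    Trainer(num_sanity_val_steps=0)    # skip the sanity-check pass entirely
    Trainer(num_sanity_val_steps=2)    # default: two validation batches
    Trainer(num_sanity_val_steps=-1)   # sanity-check the whole validation set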
