diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index 28a879ef7..6a892faef 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -458,7 +458,16 @@ def validation_step(self, batch, batch_idx):
             decoded_targets.append(''.join([x[0] for x in self.val_codec.decode([(x, 0, 0, 0) for x in batch['target'][idx:idx+offset]])]))
             idx += offset
         self.val_cer.update(pred, decoded_targets)
-
+
+        # Log the first 16 validation samples (input image + decoded prediction)
+        # to the experiment logger, skipping the Lightning sanity-check stage.
+        if self.logger and self.trainer.state.stage != 'sanity_check' and self.hparams.batch_size * batch_idx < 16:
+            for i in range(len(decoded_targets)):
+                count = self.hparams.batch_size * batch_idx + i
+                if count < 16:
+                    self.logger.experiment.add_image(f'Validation #{count}, target: {decoded_targets[i]}', batch['image'][i], self.global_step, dataformats="CHW")
+                    self.logger.experiment.add_text(f'Validation #{count}, target: {decoded_targets[i]}', pred[i], self.global_step)
+
     def on_validation_epoch_end(self):
         self.val_cer.compute()
         accuracy = 1.0 - self.val_cer.compute()
@@ -476,6 +485,16 @@ def on_validation_epoch_end(self):
     def setup(self, stage: Optional[str] = None):
         # finalize models in case of appending/loading
         if stage in [None, 'fit']:
+
+            # Log a few sample images before the datasets are encoded.
+            # This is only possible for Arrow datasets, because the
+            # other dataset types can only be accessed after encoding
+            if self.logger and isinstance(self.train_set.dataset, ArrowIPCRecognitionDataset):
+                for i in range(min(len(self.train_set), 16)):
+                    idx = np.random.randint(len(self.train_set))
+                    sample = self.train_set[idx]
+                    self.logger.experiment.add_image(f'train_set sample #{i}: {sample["target"]}', sample['image'])
+
             if self.append:
                 self.train_set.dataset.encode(self.codec)
                 # now we can create a new model