From cda9b181d8c44bd0525bfe0fbfcb1e534e2c1826 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 30 Sep 2024 12:01:43 +0200 Subject: [PATCH] put costs on correct device --- kraken/lib/train/segmentation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index a6b5d2a0..b68fd79a 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -279,10 +279,11 @@ def validation_step(self, batch, batch_idx): costs = np.sort(costs)[:len(y_curves[line_cls][0])] penalty = np.full(diff, 8.0) costs = np.concatenate([costs, penalty]) - self.val_line_mean_dist.update(costs/8.0) + costs = costs/8.0 # no line output else: - self.val_line_mean_dist.update(torch.ones(len(y_curves[line_cls][0]))) + costs = torch.ones(len(y_curves[line_cls][0])) + self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device)) def on_validation_epoch_end(self): if not self.trainer.sanity_checking: