diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index dcb9d8764..a6b5d2a0e 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -267,7 +267,7 @@ def validation_step(self, batch, batch_idx): # vectorize and match lines for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items(): - pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal') + pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal') pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl] if pred_curves: pred_curves = torch.stack(pred_curves)