diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index 48b3e30c4..5ce68c99b 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -283,9 +283,9 @@ def validation_step(self, batch, batch_idx): # num of predictions differs from target -> take n best # predictions and add error penalty term for the rest. if diff := abs(len(pred_curves) - len(target_curves)): - costs = np.sort(costs)[:len(target_curves)] - penalty = np.full(diff, 8.0) - costs = np.concatenate([costs, penalty]) + costs = torch.sort(costs)[:len(target_curves)] + penalty = torch.full((diff,), 8.0) + costs = torch.cat([costs, penalty]) costs = costs/8.0 # no line output else: