From 30c41ec84a456e730377abc28bf84051413c9213 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 29 Sep 2024 04:04:50 +0200 Subject: [PATCH] data types, fallback, ... --- kraken/lib/dataset/segmentation.py | 6 ++++-- kraken/lib/segmentation.py | 10 ++++++---- kraken/lib/train/segmentation.py | 17 ++++++++++------- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/kraken/lib/dataset/segmentation.py b/kraken/lib/dataset/segmentation.py index 0a288d5a..3f89d346 100644 --- a/kraken/lib/dataset/segmentation.py +++ b/kraken/lib/dataset/segmentation.py @@ -202,7 +202,7 @@ def transform(self, image, target): continue for line in lines: # buffer out line to desired width - line = np.array([k for k, g in groupby(line)]) + line = np.array([k for k, g in groupby(line)], dtype=np.float32) shp_line = geom.LineString(line*scale) split_offset = min(5, shp_line.length/2) line_pol = np.array(shp_line.buffer(self.line_width/2, cap_style=2).boundary.coords, dtype=int) @@ -224,7 +224,9 @@ def transform(self, image, target): t[end_sep_cls, rr, cc] = 0 # Bézier curve fitting if self.return_curves: - curves[key].append(to_curve(line, orig_size)) + curves[key].append(to_curve(torch.from_numpy(line), orig_size)) + for k, v in curves.items(): + curves[k] = torch.stack(v) for key, regions in target['regions'].items(): try: cls_idx = self.class_mapping['regions'][key] diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index c929546e..69cc490d 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -52,6 +52,7 @@ _T_pil_or_np = TypeVar('_T_pil_or_np', Image.Image, np.ndarray) +_T_tensor_or_np = TypeVar('_T_tensor_or_np', torch.Tensor, np.ndarray) logger = logging.getLogger('kraken') @@ -1435,10 +1436,11 @@ def to_curve(baseline: torch.FloatTensor, Returns: Tensor of shape (8,) """ - baseline = np.array(baseline) if len(baseline) < min_points: ls = LineString(baseline) - baseline = np.stack([np.array(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)]) - curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))/im_size + baseline = torch.stack([torch.tensor(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)]) + baseline = baseline.numpy() + curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]])) + curve = curve/im_size curve = curve.flatten() - return torch.from_numpy(curve) + return torch.from_numpy(curve.astype(baseline.dtype)) diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index 7cd1f403..08c1692c 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -264,14 +264,17 @@ def validation_step(self, batch, batch_idx): st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator'] end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator'] - # vectorize and match lines - for line_cls, line_idx in self.nn.user_metadata['class_mapping']['lines'].items(): - pred_curves = torch.stack([to_curve(pred_bl, pred.shape[:2][-1]) for pred_bl in vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...], - text_direction='horizontal')]) - cost_curves = torch.cdist(pred_curves, y_curves[line_cls], p=1).view(len(pred_curves), -1).cpu() - row_ind, col_ind = linear_sum_assignment(cost_curves) - self.val_line_dist.update(cost_curves[row_ind, col_ind]) + for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items(): + pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal') + pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl] + if pred_curves: + pred_curves = torch.stack(pred_curves) + cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu() + row_ind, col_ind = linear_sum_assignment(cost_curves) + self.val_line_dist.update(cost_curves[row_ind, col_ind]/8.0) + else: + self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0]))) def on_validation_epoch_end(self): if not self.trainer.sanity_checking: