Skip to content

Commit

Permalink
data types, fallback, ...
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 29, 2024
1 parent d53b4fa commit 30c41ec
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
6 changes: 4 additions & 2 deletions kraken/lib/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def transform(self, image, target):
continue
for line in lines:
# buffer out line to desired width
line = np.array([k for k, g in groupby(line)])
line = np.array([k for k, g in groupby(line)], dtype=np.float32)
shp_line = geom.LineString(line*scale)
split_offset = min(5, shp_line.length/2)
line_pol = np.array(shp_line.buffer(self.line_width/2, cap_style=2).boundary.coords, dtype=int)
Expand All @@ -224,7 +224,9 @@ def transform(self, image, target):
t[end_sep_cls, rr, cc] = 0
# Bézier curve fitting
if self.return_curves:
curves[key].append(to_curve(line, orig_size))
curves[key].append(to_curve(torch.from_numpy(line), orig_size))
for k, v in curves.items():
curves[k] = torch.stack(v)
for key, regions in target['regions'].items():
try:
cls_idx = self.class_mapping['regions'][key]
Expand Down
10 changes: 6 additions & 4 deletions kraken/lib/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@


_T_pil_or_np = TypeVar('_T_pil_or_np', Image.Image, np.ndarray)
_T_tensor_or_np = TypeVar('_T_tensor_or_np', torch.Tensor, np.ndarray)

logger = logging.getLogger('kraken')

Expand Down Expand Up @@ -1435,10 +1436,11 @@ def to_curve(baseline: torch.FloatTensor,
Returns:
Tensor of shape (8,)
"""
baseline = np.array(baseline)
if len(baseline) < min_points:
ls = LineString(baseline)
baseline = np.stack([np.array(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)])
curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))/im_size
baseline = torch.stack([torch.tensor(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)])
baseline = baseline.numpy()
curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))
curve = curve/im_size
curve = curve.flatten()
return torch.from_numpy(curve)
return torch.from_numpy(curve.astype(baseline.dtype))
17 changes: 10 additions & 7 deletions kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,17 @@ def validation_step(self, batch, batch_idx):
st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator']
end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator']


# vectorize and match lines
for line_cls, line_idx in self.nn.user_metadata['class_mapping']['lines'].items():
pred_curves = torch.stack([to_curve(pred_bl, pred.shape[:2][-1]) for pred_bl in vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...],
text_direction='horizontal')])
cost_curves = torch.cdist(pred_curves, y_curves[line_cls], p=1).view(len(pred_curves), -1).cpu()
row_ind, col_ind = linear_sum_assignment(cost_curves)
self.val_line_dist.update(cost_curves[row_ind, col_ind])
for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal')
pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]
if pred_curves:
pred_curves = torch.stack(pred_curves)
cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu()
row_ind, col_ind = linear_sum_assignment(cost_curves)
self.val_line_dist.update(cost_curves[row_ind, col_ind]/8.0)
else:
self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))

def on_validation_epoch_end(self):
if not self.trainer.sanity_checking:
Expand Down

0 comments on commit 30c41ec

Please sign in to comment.