diff --git a/kraken/lib/dataset/recognition.py b/kraken/lib/dataset/recognition.py index 5f8744724..5a13d2f00 100644 --- a/kraken/lib/dataset/recognition.py +++ b/kraken/lib/dataset/recognition.py @@ -15,21 +15,24 @@ """ Utility functions for data loading and training of VGSL networks. """ -import dataclasses import io import json +import torch +import numpy as np +import pyarrow as pa import traceback +import dataclasses +import multiprocessing as mp + from collections import Counter from functools import partial from typing import (TYPE_CHECKING, Any, Callable, List, Literal, Optional, Tuple, Union) -import numpy as np -import pyarrow as pa -import torch from PIL import Image -from torch.utils.data import Dataset +from ctypes import c_char from torchvision import transforms +from torch.utils.data import Dataset from kraken.containers import BaselineLine, BBoxLine, Segmentation from kraken.lib import functional_im_transforms as F_t @@ -331,7 +334,7 @@ def __init__(self, if augmentation: self.aug = DefaultAugmenter() - self.im_mode = '1' + self._im_mode = mp.Value(c_char, b'1') def add(self, line: Optional[BaselineLine] = None, @@ -437,15 +440,16 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: legacy=self.legacy_polygons)) im = self.transforms(im) if im.shape[0] == 3: - im_mode = 'RGB' + im_mode = b'R' elif im.shape[0] == 1: - im_mode = 'L' + im_mode = b'L' if is_bitonal(im): - im_mode = '1' + im_mode = b'1' - if im_mode > self.im_mode: - logger.info(f'Upgrading "im_mode" from {self.im_mode} to {im_mode}') - self.im_mode = im_mode + with self._im_mode.get_lock(): + if im_mode > self._im_mode.value: + logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}') + self._im_mode.value = im_mode if self.aug: im = im.permute((1, 2, 0)).numpy() o = self.aug(image=im) @@ -461,6 +465,12 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: def __len__(self) -> int: return len(self._images) + @property + def im_mode(self): + return {b'1': '1', + b'L': 'L', + b'R': 'RGB'}[self._im_mode.value] + class GroundTruthDataset(Dataset): """ @@ -520,7 +530,7 @@ def __init__(self, if augmentation: self.aug = DefaultAugmenter() - self.im_mode = '1' + self._im_mode = mp.Value(c_char, b'1') def add(self, line: Optional[BBoxLine] = None, @@ -616,14 +626,15 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: im = im.crop((xmin, ymin, xmax, ymax)) im = self.transforms(im) if im.shape[0] == 3: - im_mode = 'RGB' + im_mode = b'R' elif im.shape[0] == 1: - im_mode = 'L' + im_mode = b'L' if is_bitonal(im): - im_mode = '1' - if im_mode > self.im_mode: - logger.info(f'Upgrading "im_mode" from {self.im_mode} to {im_mode}') - self.im_mode = im_mode + im_mode = b'1' + with self._im_mode.get_lock(): + if im_mode > self._im_mode.value: + logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}') + self._im_mode.value = im_mode if self.aug: im = im.permute((1, 2, 0)).numpy() o = self.aug(image=im) @@ -639,3 +650,11 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: def __len__(self) -> int: return len(self._images) + + @property + def im_mode(self): + return {b'1': '1', + b'L': 'L', + b'R': 'RGB'}[self._im_mode.value] + + diff --git a/kraken/lib/util.py b/kraken/lib/util.py index 609e924fe..267b4545b 100644 --- a/kraken/lib/util.py +++ b/kraken/lib/util.py @@ -56,7 +56,7 @@ def is_bitonal(im: Union[Image.Image, torch.Tensor]) -> bool: if isinstance(im, Image.Image): return im.getcolors(2) is not None and len(im.getcolors(2)) == 2 elif isinstance(im, torch.Tensor): - return len(im.int().unique()) == 2 + return len(im.unique()) == 2 def get_im_str(im: Image.Image) -> str: