Skip to content

Commit

Permalink
Fixes 8-bit image mode setting in datasets
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit 5524d88
Author: Benjamin Kiessling <[email protected]>
Date:   Mon Apr 8 20:09:28 2024 +0200

    Shared im_mode setting in GroundTruthDataset/PolygonGTDataset

commit 46be12c
Author: Benjamin Kiessling <[email protected]>
Date:   Sun Apr 7 23:33:45 2024 +0200

    put image mode in shared memory
  • Loading branch information
mittagessen committed Apr 8, 2024
1 parent d814500 commit 4c4e375
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
57 changes: 38 additions & 19 deletions kraken/lib/dataset/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,24 @@
"""
Utility functions for data loading and training of VGSL networks.
"""
import dataclasses
import io
import json
import torch
import numpy as np
import pyarrow as pa
import traceback
import dataclasses
import multiprocessing as mp

from collections import Counter
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, List, Literal, Optional,
Tuple, Union)

import numpy as np
import pyarrow as pa
import torch
from PIL import Image
from torch.utils.data import Dataset
from ctypes import c_char
from torchvision import transforms
from torch.utils.data import Dataset

from kraken.containers import BaselineLine, BBoxLine, Segmentation
from kraken.lib import functional_im_transforms as F_t
Expand Down Expand Up @@ -331,7 +334,7 @@ def __init__(self,
if augmentation:
self.aug = DefaultAugmenter()

self.im_mode = '1'
self._im_mode = mp.Value(c_char, b'1')

def add(self,
line: Optional[BaselineLine] = None,
Expand Down Expand Up @@ -437,15 +440,16 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
legacy=self.legacy_polygons))
im = self.transforms(im)
if im.shape[0] == 3:
im_mode = 'RGB'
im_mode = b'R'
elif im.shape[0] == 1:
im_mode = 'L'
im_mode = b'L'
if is_bitonal(im):
im_mode = '1'
im_mode = b'1'

if im_mode > self.im_mode:
logger.info(f'Upgrading "im_mode" from {self.im_mode} to {im_mode}')
self.im_mode = im_mode
with self._im_mode.get_lock():
if im_mode > self._im_mode.value:
logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}')
self._im_mode.value = im_mode
if self.aug:
im = im.permute((1, 2, 0)).numpy()
o = self.aug(image=im)
Expand All @@ -461,6 +465,12 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
def __len__(self) -> int:
    """
    Returns the number of entries currently held in `self._images`.
    """
    n_entries = len(self._images)
    return n_entries

@property
def im_mode(self):
    """
    Image mode of the dataset as a mode string ('1', 'L', or 'RGB').

    The mode is kept internally as a single-byte flag in `self._im_mode`;
    this accessor translates the flag back into the mode string.

    Raises:
        KeyError: if the stored flag is not one of b'1', b'L', or b'R'.
    """
    mode_names = {b'1': '1', b'L': 'L', b'R': 'RGB'}
    flag = self._im_mode.value
    return mode_names[flag]


class GroundTruthDataset(Dataset):
"""
Expand Down Expand Up @@ -520,7 +530,7 @@ def __init__(self,
if augmentation:
self.aug = DefaultAugmenter()

self.im_mode = '1'
self._im_mode = mp.Value(c_char, b'1')

def add(self,
line: Optional[BBoxLine] = None,
Expand Down Expand Up @@ -616,14 +626,15 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
im = im.crop((xmin, ymin, xmax, ymax))
im = self.transforms(im)
if im.shape[0] == 3:
im_mode = 'RGB'
im_mode = b'R'
elif im.shape[0] == 1:
im_mode = 'L'
im_mode = b'L'
if is_bitonal(im):
im_mode = '1'
if im_mode > self.im_mode:
logger.info(f'Upgrading "im_mode" from {self.im_mode} to {im_mode}')
self.im_mode = im_mode
im_mode = b'1'
with self._im_mode.get_lock():
if im_mode > self._im_mode.value:
logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}')
self._im_mode.value = im_mode
if self.aug:
im = im.permute((1, 2, 0)).numpy()
o = self.aug(image=im)
Expand All @@ -639,3 +650,11 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:

def __len__(self) -> int:
    """
    Returns the size of the dataset, i.e. the length of `self._images`.
    """
    size = len(self._images)
    return size

@property
def im_mode(self):
    """
    Returns the dataset's image mode as one of '1', 'L', or 'RGB'.

    `self._im_mode` stores the mode as a one-byte code (b'1', b'L', b'R');
    the code is mapped to its full mode string here.

    Raises:
        KeyError: if the stored code is not a known mode code.
    """
    code_to_mode = dict(((b'1', '1'), (b'L', 'L'), (b'R', 'RGB')))
    return code_to_mode[self._im_mode.value]


2 changes: 1 addition & 1 deletion kraken/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def is_bitonal(im: Union[Image.Image, torch.Tensor]) -> bool:
if isinstance(im, Image.Image):
return im.getcolors(2) is not None and len(im.getcolors(2)) == 2
elif isinstance(im, torch.Tensor):
return len(im.int().unique()) == 2
return len(im.unique()) == 2


def get_im_str(im: Image.Image) -> str:
Expand Down

0 comments on commit 4c4e375

Please sign in to comment.