diff --git a/README.md b/README.md
index d8f0f59..e20d3bf 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 Recognition and doing inference with them.
 
 ## Feature highlights
-* [DeepFont-like network architecture](https://arxiv.org/pdf/1507.03196v1.pdf)
+* DeepFont-like network architecture. See [Z. Wang, J. Yang, H. Jin, E. Shechtman, A. Agarwala, J. Brandt and T. Huang, “DeepFont: Identify Your Font from An Image”, in Proceedings of the ACM International Conference on Multimedia (ACM MM), 2015](https://arxiv.org/abs/1507.03196)
 * Configuration-based synthetic dataset generation
 * Configuration-based model training via [PyTorch Lightning](https://lightning.ai/pytorch-lightning)
 * Supports training and inference on Linux, MacOS and Windows.
@@ -47,7 +47,7 @@ make test
 
 ### Generating a synthetic dataset
 If needed, the model can be trained on synthetic data. `fontina` provides a synthetic
-dataset generator that follows part of the recommendations from the [DeepFont paper](https://arxiv.org/pdf/1507.03196v1.pdf)
+dataset generator that follows part of the recommendations from the [DeepFont paper](https://arxiv.org/abs/1507.03196)
 to make the synthetic data look closer to the real data. To use the generator:
 
 1. Make a copy of `configs/sample.yaml`, e.g. `configs/mymodel.yaml`
@@ -75,7 +75,7 @@ fonts:
 3. Run the generation:
 
 ```bash
-python src/fontina/generate.py -c configs/mymodel.yaml -o outputs/font-images/mymodel
+fontina-generate -c configs/mymodel.yaml -o outputs/font-images/mymodel
 ```
 
 After this completes, there should be one directory per configured font in `outputs/font-images/mymodel`.
@@ -124,7 +124,7 @@ python src/fontina/train.py -c configs/mymodel.yaml
 ```
 
 #### Part 2 - Supervised training
-1. Open `configs/mymodel.yaml` and tweak the `training` section:
+1. Open `configs/mymodel.yaml` (or create a new one!) and tweak the `training` section:
 
 ```yaml
 training:
@@ -155,7 +155,7 @@ training:
 2. Then run the training with:
 
 ```bash
-python src/fontina/train.py -c configs/mymodel.yaml
+fontina-train -c configs/mymodel.yaml
 ```
 
 ### **(Optional)** - Monitor performance using TensorBoard
@@ -171,5 +171,31 @@ tensorboard --logdir=lightning_logs
 Once training is complete, the resulting model can be used to run inference.
 
 ```bash
-python src/fontina/predict.py -w "outputs/models/mymodel-full/best_checkpoint.ckpt" -i "assets/images/test.png"
+fontina-predict -n 6 -w "outputs/models/mymodel-full/best_checkpoint.ckpt" -i "assets/images/test.png"
 ```
+
+## AdobeVFR Pre-trained model
+The AdobeVFR dataset is currently available for download [at Dropbox, here](https://www.dropbox.com/sh/o320sowg790cxpe/AADDmdwQ08GbciWnaC20oAmna?dl=0). The license for using and distributing the dataset is available [here](https://www.dropbox.com/sh/o320sowg790cxpe/AADDmdwQ08GbciWnaC20oAmna?dl=0&preview=license.txt), which states:
+
+> This dataset ('Licensed Material') is made available to the scientific community for non-commercial research purposes such as academic research, teaching, scientific publications or personal experimentation.
+
+The model is trained on that dataset, so it is distributed in the same spirit and under the same license: the released model can only be used for non-commercial purposes.
+
+### How to train
+
+1. Download the dataset to `assets/AdobeVFR`
+2. Unpack `assets/AdobeVFR/Raw Image/VFR_real_u/scrape-wtf-new.zip` in that directory so that the `assets/AdobeVFR/Raw Image/VFR_real_u/scrape-wtf-new/` path exists
+3. Run `fontina-train -c configs/adobe-vfr-autoencoder.yaml`. This will take a long while, but progress can be checked with TensorBoard (see the previous sections) during training.
+4. Change `configs/adobe-vfr.yaml` so that `scae_checkpoint_file` points to the best checkpoint from step (3).
+5. Run `fontina-train -c configs/adobe-vfr.yaml`. This will take a long while (but less than the unsupervised training round).
+
+### Downloading the models
+While only the full model is needed for inference, the stand-alone autoencoder model is being released as well.
+
+* Stand-alone autoencoder model: [Google Drive](https://drive.google.com/file/d/107Ontyg2FGxOKvhE7KM7HSaJ1Wn2Merr/view?usp=sharing)
+* Full model: [Google Drive](https://drive.google.com/file/d/1Fw-bjmapCXe0aCiYvOyGLmYocZDvmptK/view?usp=drive_link)
+
+> **Note**
+> The pre-trained model achieves a validation loss of 0.3523, with an accuracy of 0.8855 after 14 epochs.
+> Unfortunately, the test performance on `VFR_real_test` is much worse, with a top-1 accuracy of 0.05.
+> I'm releasing the model in the hope that somebody can help me fix this 😊😅
diff --git a/configs/adobe-vfr-autoencoder.yaml b/configs/adobe-vfr-autoencoder.yaml
new file mode 100644
index 0000000..f535ccb
--- /dev/null
+++ b/configs/adobe-vfr-autoencoder.yaml
@@ -0,0 +1,87 @@
+---
+# This section of the configuration is used to control
+# the generation of the synthetic image data for the
+# visual font recognition task.
+fonts:
+  # Whether or not to enable random spacing between characters.
+  random_character_spacing: False
+
+  # The regular expression to use to generate the text
+  # in the synthetic image samples.
+  regex_template: '[A-Z0-9]{5,10} [A-Z0-9]{3,7}'
+
+  # The path to the directory containing background images.
+  # If provided, images in this directory will be used as
+  # background for the generated text. If omitted, images
+  # will have a white background.
+  backgrounds_path: "assets/backgrounds"
+
+  # The number of samples to generate for each font.
+  samples_per_font: 50
+
+  classes:
+    - name: Test Font
+      path: "assets/fonts/test/Test.ttf"
+    - name: Other Test Font
+      path: "assets/fonts/test2/Test2.ttf"
+
+# This section controls the training configuration for the model.
+training:
+  only_autoencoder: True
+
+  # The path to the pre-trained checkpoint to use for the
+  # stacked autoencoders within the DeepFont-like model. Setting
+  # this property skips training the SCAE.
+  # scae_checkpoint_file: "outputs/adobevfr/final/autoenc-epoch=13-val_loss=0.0016.ckpt"
+
+  # Whether or not to use a fixed random seed for training. Note
+  # that this is useful for creating reproducible runs for debugging
+  # purposes.
+  # fixed_seed: 42
+
+  # The type of data source stored in the data root.
+  # It's one of:
+  #  - "raw-images": the data root contains one directory
+  #    per font type, each having the samples coming from
+  #    that font.
+  #  - "adobevfr": the data root contains the AdobeVFR dataset in
+  #    BCF format, i.e. the 'VFR_real_test', 'VFR_syn_train'
+  #    and 'VFR_syn_val' directories.
+  dataset_type: "adobevfr"
+
+  # The root directory containing the data generated from the
+  # synthetic image generation step.
+  data_root: "assets/AdobeVFR"
+
+  # The directory that will contain the model checkpoints.
+  output_dir: "outputs/adobevfr/autoenc"
+
+  # The number of workers to use for the data loaders. See
+  # the PyTorch documentation here:
+  # https://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader
+  num_workers: 12
+
+  # The size of the batch to use for training.
+  batch_size: 128
+
+  # The initial learning rate to use for training.
+  learning_rate: 0.01
+
+  epochs: 10
+
+  # The ratio to use for splitting the samples in the data
+  # root into train, validation and test sets.
+  # Note that the validation set is used for validation
+  # during the training cycle, while the testing set, if
+  # provided, is used after the training phase is complete.
+  train_ratio: 0.8
+  # The following ratios are meaningful only if run_test_cycle
+  # is enabled.
+  validation_ratio: 0.1
+  test_ratio: 0.1
+
+  # Whether or not to use a fraction of the data to run a
+  # test cycle on the trained model. If this is disabled
+  # then only the train ratio will be used: the validation
+  # ratio will be automatically computed.
+  run_test_cycle: True
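A quick way to sanity-check the new config before kicking off a long unsupervised run is to load it with plain PyYAML. This is only an illustrative sketch, not fontina's own config loader (which is not part of this diff):

```python
import yaml

# Load the new config and print the knobs that matter most for the
# unsupervised (autoencoder-only) round.
with open("configs/adobe-vfr-autoencoder.yaml") as f:
    cfg = yaml.safe_load(f)

training = cfg["training"]
print(training["only_autoencoder"])  # True
print(training["dataset_type"], training["data_root"])  # adobevfr assets/AdobeVFR
print(training["batch_size"], training["learning_rate"], training["epochs"])
```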
diff --git a/configs/adobe-vfr.yaml b/configs/adobe-vfr.yaml
new file mode 100644
index 0000000..f3016eb
--- /dev/null
+++ b/configs/adobe-vfr.yaml
@@ -0,0 +1,88 @@
+---
+# This section of the configuration is used to control
+# the generation of the synthetic image data for the
+# visual font recognition task.
+fonts:
+  # Whether or not to enable random spacing between characters.
+  random_character_spacing: False
+
+  # The regular expression to use to generate the text
+  # in the synthetic image samples.
+  regex_template: '[A-Z0-9]{5,10} [A-Z0-9]{3,7}'
+
+  # The path to the directory containing background images.
+  # If provided, images in this directory will be used as
+  # background for the generated text. If omitted, images
+  # will have a white background.
+  backgrounds_path: "assets/backgrounds"
+
+  # The number of samples to generate for each font.
+  samples_per_font: 50
+
+  classes:
+    - name: Test Font
+      path: "assets/fonts/test/Test.ttf"
+    - name: Other Test Font
+      path: "assets/fonts/test2/Test2.ttf"
+
+# This section controls the training configuration for the model.
+training:
+  # TODO: When training the autoencoder, use the real images.
+  only_autoencoder: False
+
+  # The path to the pre-trained checkpoint to use for the
+  # stacked autoencoders within the DeepFont-like model. Setting
+  # this property skips training the SCAE.
+  scae_checkpoint_file: "outputs/adobevfr/final/v82-autoenc-epoch=10-val_loss=0.0019-val_accuracy=0.0000.ckpt"
+
+  # Whether or not to use a fixed random seed for training. Note
+  # that this is useful for creating reproducible runs for debugging
+  # purposes.
+  # fixed_seed: 42
+
+  # The type of data source stored in the data root.
+  # It's one of:
+  #  - "raw-images": the data root contains one directory
+  #    per font type, each having the samples coming from
+  #    that font.
+  #  - "adobevfr": the data root contains the AdobeVFR dataset in
+  #    BCF format, i.e. the 'VFR_real_test', 'VFR_syn_train'
+  #    and 'VFR_syn_val' directories.
+  dataset_type: "adobevfr"
+
+  # The root directory containing the data generated from the
+  # synthetic image generation step.
+  data_root: "assets/AdobeVFR"
+
+  # The directory that will contain the model checkpoints.
+  output_dir: "outputs/adobevfr/full"
+
+  # The number of workers to use for the data loaders. See
+  # the PyTorch documentation here:
+  # https://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader
+  num_workers: 12
+
+  # The size of the batch to use for training.
+  batch_size: 128
+
+  # The initial learning rate to use for training.
+  learning_rate: 0.01
+
+  epochs: 20
+
+  # The ratio to use for splitting the samples in the data
+  # root into train, validation and test sets.
+  # Note that the validation set is used for validation
+  # during the training cycle, while the testing set, if
+  # provided, is used after the training phase is complete.
+  train_ratio: 0.8
+  # The following ratios are meaningful only if run_test_cycle
+  # is enabled.
+  validation_ratio: 0.1
+  test_ratio: 0.1
+
+  # Whether or not to use a fraction of the data to run a
+  # test cycle on the trained model. If this is disabled
+  # then only the train ratio will be used: the validation
+  # ratio will be automatically computed.
+  run_test_cycle: True
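The `train_ratio`/`validation_ratio`/`test_ratio` values above are plain fractions. For illustration only (this is not fontina's data-loading code), `torch.utils.data.random_split` — the same helper `train.py` uses further down for its 0.95/0.05 AdobeVFR split — accepts such fractions directly:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# 100 dummy samples split 0.8/0.1/0.1, mirroring the ratios in the config.
data = TensorDataset(torch.arange(100))
train_set, val_set, test_set = random_split(data, [0.8, 0.1, 0.1])
print(len(train_set), len(val_set), len(test_set))  # 80 10 10
```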
diff --git a/pyproject.toml b/pyproject.toml
index 066f16a..bce624b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@
 readme = "README.md"
 license = { text = "MIT" }
 
 [project.scripts]
+fontina-generate = "fontina.generate:main"
 fontina-train = "fontina.train:main"
 fontina-predict = "fontina.predict:main"
diff --git a/src/fontina/adobevfr_dataset.py b/src/fontina/adobevfr_dataset.py
index 2965940..f875623 100644
--- a/src/fontina/adobevfr_dataset.py
+++ b/src/fontina/adobevfr_dataset.py
@@ -1,8 +1,8 @@
-import numpy as np
+import cv2
 import io
+import numpy as np
 import torch
 
-from PIL import Image
 from torch.utils.data import Dataset
 
 
@@ -26,12 +26,15 @@ def __init__(self, bcf_path: str, dataset_type: str, transform=None):
 
     def __getitem__(self, index):
         binary_image = self._get_bcf_entry_by_index(index)
-        pil_image = Image.open(io.BytesIO(binary_image)).convert("L")
-        raw_image = np.array(pil_image, dtype="uint8")
+        image_as_array = np.asarray(
+            bytearray(io.BytesIO(binary_image).read()), dtype=np.uint8
+        )
+        cv2image = cv2.imdecode(image_as_array, cv2.IMREAD_GRAYSCALE)
+        raw_image = np.array(cv2image, dtype="uint8")
         x = self.transform(image=raw_image)["image"] if self.transform else raw_image
         # We need to cast to `torch.long` to prevent errors such as
         # "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Int'.
-        return x, torch.tensor(self.labels[index], dtype=torch.long)
+        return x, torch.as_tensor(self.labels[index], dtype=torch.long)
 
     def __len__(self):
         return len(self._bcf_offsets) - 1
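The `__getitem__` change above swaps PIL for OpenCV when decoding the raw BCF bytes into a single-channel image. A self-contained sketch of that decoding path (the sample bytes here are produced with `cv2.imencode` purely for illustration; in the dataset they come from `_get_bcf_entry_by_index`):

```python
import cv2
import numpy as np

# Fake "BCF entry": PNG-encode a small grayscale image to get raw bytes,
# standing in for what _get_bcf_entry_by_index() would return.
ok, encoded = cv2.imencode(".png", np.full((8, 8), 128, dtype=np.uint8))
assert ok
binary_image = encoded.tobytes()

# The decoding path used by the new __getitem__:
# bytes -> uint8 array -> single-channel grayscale image.
image_as_array = np.asarray(bytearray(binary_image), dtype=np.uint8)
raw_image = cv2.imdecode(image_as_array, cv2.IMREAD_GRAYSCALE)
print(raw_image.shape, raw_image.dtype)  # (8, 8) uint8
```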
diff --git a/src/fontina/augmentation_utils.py b/src/fontina/augmentation_utils.py
index 886d53d..5d6e1bf 100644
--- a/src/fontina/augmentation_utils.py
+++ b/src/fontina/augmentation_utils.py
@@ -1,18 +1,18 @@
 import albumentations as A
+import cv2
 import numpy as np
 import numpy.typing as npt
 
 from albumentations.pytorch import ToTensorV2
 from albumentations.core.transforms_interface import ImageOnlyTransform
-from PIL import Image
 
 
-def resize_fixed_height(img: Image.Image, new_height: int = 105):
+def resize_fixed_height(img: npt.NDArray[np.uint8], new_height: int = 105):
     # From the paper: height is fixed to 105 pixels, width is scaled
     # to keep aspect ratio.
-    width, height = img.size
+    height, width = img.shape[:2]
     new_width = round(new_height * width / height)
-    return img.resize((new_width, new_height), Image.LANCZOS)
+    return cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
 
 
 def split_patches_np(img: npt.NDArray[np.uint8], step: int, drop_last: bool):
@@ -23,7 +23,7 @@ def split_patches_np(img: npt.NDArray[np.uint8], step: int, drop_last: bool):
         patches.append(img[0:height, x : x + step])
 
     # Fixup the last patch instead of dropping it because
-    # its width is smaller than 105. When cropping with PIL and
+    # its width is smaller than 105. When cropping, if
     # the patch is smaller than the needed area, it gets filled
     # with black pixels. We should recolor them instead of discarding.
     available_width = width % step
@@ -64,7 +64,11 @@ def apply(self, img, **params):
         _, width = img.shape
 
         if width <= 105:
-            return img
+            return np.append(
+                img,
+                np.full((105, 105 - width), 255, dtype="uint8"),
+                axis=1,
+            )
 
         if not self.constrained_patches:
             start_x = np.random.randint(0, width - 105)
@@ -89,9 +93,7 @@ def apply(self, img, **params):
         height, width = img.shape
         ratio = np.random.uniform(low=self.ratio_range[0], high=self.ratio_range[1])
         new_width = round(width * ratio)
-        squeezed = Image.fromarray(img).resize(
-            (new_width, height), Image.Resampling.LANCZOS
-        )
+        squeezed = cv2.resize(img, (new_width, height), interpolation=cv2.INTER_LANCZOS4)
         return np.array(squeezed)
 
     def get_transform_init_args_names(self):
@@ -113,9 +115,7 @@ def __init__(self, squeeze_ratio, always_apply=False, p=1.0) -> None:
     def apply(self, img, **params):
         height, width = img.shape
         new_width = round(height * self.squeeze_ratio)
-        squeezed = Image.fromarray(img).resize(
-            (new_width, height), Image.Resampling.LANCZOS
-        )
+        squeezed = cv2.resize(img, (new_width, height), interpolation=cv2.INTER_LANCZOS4)
         return np.array(squeezed, dtype="uint8")
 
     def get_transform_init_args_names(self):
@@ -132,7 +132,7 @@ def __init__(self, target_height: int, always_apply=False, p=1.0) -> None:
         self.target_height = target_height
 
     def apply(self, img, **params):
-        resized = resize_fixed_height(Image.fromarray(img), self.target_height)
+        resized = resize_fixed_height(img, self.target_height)
         return np.array(resized)
 
 
@@ -198,10 +198,10 @@ def get_random_square_patch_augmentation() -> A.Compose:
     )
 
 
-def get_test_augmentations(r: float) -> A.Compose:
+def get_test_augmentations(squeeze_ratio: float) -> A.Compose:
     return A.Sequential(
         [
             ResizeHeight(target_height=105, always_apply=True),
-            Squeezing(squeeze_ratio=r, always_apply=True),
+            Squeezing(squeeze_ratio=squeeze_ratio, always_apply=True),
         ]
     )
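With the rename from `r` to `squeeze_ratio`, the test-time pipeline reads more clearly at the call site. A minimal usage sketch on a made-up grayscale array (shapes are illustrative only; `predict.py` below drives the same pipeline with a random ratio):

```python
import numpy as np

from fontina.augmentation_utils import get_test_augmentations

# Dummy white "text line": 70 px tall, 300 px wide (illustrative values only).
img = np.full((70, 300), 255, dtype=np.uint8)

# Resize to the fixed 105 px height, then squeeze the width by the given ratio.
pipeline = get_test_augmentations(squeeze_ratio=2.0)
out = pipeline(image=img)["image"]
print(out.shape)  # height is always 105; the width depends on the squeeze ratio
```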
diff --git a/src/fontina/augmented_dataset.py b/src/fontina/augmented_dataset.py
index 74e0787..a7bdd6d 100644
--- a/src/fontina/augmented_dataset.py
+++ b/src/fontina/augmented_dataset.py
@@ -1,4 +1,5 @@
 import numpy as np
+import torch
 
 from torch.utils.data import Dataset
 
@@ -11,9 +12,13 @@ def __init__(self, dataset, num_classes, transform=None):
         self.transform = transform
 
     def __getitem__(self, index):
-        raw_image = np.asarray(self.dataset[index][0])
-        x = self.transform(image=raw_image)["image"] if self.transform else raw_image
-        return x, self.dataset[index][1]
+        raw_image = self.dataset[index][0]
+        x = (
+            self.transform(image=np.asarray(raw_image))["image"]
+            if self.transform
+            else raw_image
+        )
+        return x, torch.as_tensor(self.dataset[index][1], dtype=torch.long)
 
     def __len__(self):
         return len(self.dataset)
diff --git a/src/fontina/predict.py b/src/fontina/predict.py
index 27ce71b..44de2f0 100644
--- a/src/fontina/predict.py
+++ b/src/fontina/predict.py
@@ -15,7 +15,7 @@
 
 
 def get_parser():
-    parser = argparse.ArgumentParser(description="Fontina detect")
+    parser = argparse.ArgumentParser(description="Fontina predict")
     parser.add_argument(
         "-w",
         "--weights",
@@ -43,7 +43,9 @@ def get_parser():
 def predict(model: DeepFontWrapper, img) -> torch.Tensor:
     all_soft_preds = []
     for _ in range(3):
-        enhancement_pipeline = get_test_augmentations(r=1.5 + np.random.rand() * 2)
+        enhancement_pipeline = get_test_augmentations(
+            squeeze_ratio=np.random.uniform(low=1.5, high=3.5)
+        )
         enhanced_img = enhancement_pipeline(image=np.asarray(img))["image"]
 
         patch_sampler = get_random_square_patch()
@@ -67,7 +69,7 @@ def main():
     )
 
     raw_img = PIL.Image.open(args.input).convert("L")
-    img = resize_fixed_height(raw_img)
+    img = resize_fixed_height(np.asarray(raw_img))
 
     predicted_class = predict(model, img)
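Because `resize_fixed_height` is now NumPy/OpenCV based, callers convert PIL images to arrays first, exactly as `main()` does above. A minimal standalone sketch, reusing the sample image path from the README (adjust the path as needed):

```python
import numpy as np
import PIL.Image

from fontina.augmentation_utils import resize_fixed_height

# Load, force single-channel grayscale, then hand a NumPy array to the
# cv2-backed resize. The result keeps the aspect ratio at a 105 px height.
raw_img = PIL.Image.open("assets/images/test.png").convert("L")
img = resize_fixed_height(np.asarray(raw_img))
print(img.shape)  # (105, width scaled to preserve the aspect ratio)
```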
diff --git a/src/fontina/train.py b/src/fontina/train.py
index 570fbd0..7ce54bd 100644
--- a/src/fontina/train.py
+++ b/src/fontina/train.py
@@ -8,8 +8,6 @@
 )
 
 import torch
-from torchvision import datasets, transforms
-from torch.utils.data import DataLoader
 from fontina.adobevfr_dataset import AdobeVFRDataset
 from fontina.augmentation_utils import (
     get_deepfont_full_augmentations,
@@ -22,6 +20,18 @@
 from fontina.models.lightning_generate_callback import GenerateCallback
 from fontina.models.lightning_wrappers import DeepFontAutoencoderWrapper, DeepFontWrapper
 
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+# torchvision's `ImageFolder` dataset uses Pillow under the hood to
+# load image files. However, PIL will fail to load some PNG
+# files by throwing a zlib decompression error:
+# "ValueError: Decompressed Data Too Large". Setting
+# `LOAD_TRUNCATED_IMAGES = True` mitigates this problem.
+from PIL import Image, ImageFile, UnidentifiedImageError
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
 
 def get_parser():
     parser = argparse.ArgumentParser(description="Fontina training")
@@ -90,31 +100,61 @@ def load_and_split_data(train_config):
 
 def load_adobevfr_dataset(train_config):
     all_train_data = AdobeVFRDataset(
-        f"{train_config['data_root']}/VFR_syn_train",
+        f"{train_config['data_root']}/BCF format/VFR_syn_train",
         "train",
         get_deepfont_full_augmentations(),
     )
+    num_labels = all_train_data.num_labels
+
+    # From the DeepFont paper: "We first train the SCAE on both synthetic and
+    # real-world data in a unsupervised way [...]". When training the autoencoder,
+    # we merge the real and synthetic datasets for training purposes.
+    if train_config["only_autoencoder"]:
+
+        def is_valid_image(path):
+            try:
+                _ = Image.open(path)
+                return True
+            except (UnidentifiedImageError, ValueError):
+                print(f"Failed to load image: {path}")
+                return False
+
+        real_data = datasets.ImageFolder(
+            root=f"{train_config['data_root']}/Raw Image/VFR_real_u",
+            # Important: albumentations can't set grayscale and output only one
+            # channel, so do it here.
+            transform=transforms.Grayscale(num_output_channels=1),
+            target_transform=None,
+            is_valid_file=is_valid_image,
+        )
+        real_data_processed = AugmentedDataset(
+            real_data, 1, get_deepfont_full_augmentations()
+        )
+
+        # Override 'all_train_data' with the joined dataset.
+        all_train_data = torch.utils.data.ConcatDataset(
+            [all_train_data, real_data_processed]
+        )
+
     # Although the AdobeVFR dataset readme says that VFR_syn_val contains
-    # the validation for the same classes as VFR_syn_train, that doens't
+    # the validation for the same classes as VFR_syn_train, that doesn't
     # seem to be the case: the former contains 2383 classes, the latter
-    # 4383. Instead of using it, let's split the rain set.
+    # 4383. Instead of using it, let's split the train set.
     splits = torch.utils.data.random_split(all_train_data, [0.95, 0.05])
-    train_set_processed = AugmentedDataset(splits[0], all_train_data.num_labels, None)
-    validation_set_processed = AugmentedDataset(
-        splits[1], all_train_data.num_labels, None
-    )
+    train_set_processed = AugmentedDataset(splits[0], num_labels, None)
+    validation_set_processed = AugmentedDataset(splits[1], num_labels, None)
 
     """
     validation_set_processed = AdobeVFRDataset(
-        f"{train_config['data_root']}/VFR_syn_val",
+        f"{train_config['data_root']}/BCF format/VFR_syn_val",
         "val",
         get_random_square_patch_augmentation(),
     )
     """
 
     test_set_processed = (
         AdobeVFRDataset(
-            f"{train_config['data_root']}/VFR_real_test",
+            f"{train_config['data_root']}/BCF format/VFR_real_test",
             "vfr_large",
             get_random_square_patch_augmentation(),
         )
@@ -123,7 +163,7 @@ def load_adobevfr_dataset(train_config):
 
     return (
-        all_train_data.num_labels,
+        num_labels,
         train_set_processed,
         validation_set_processed,
         test_set_processed,
diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py
index d3a35a4..5ef7798 100644
--- a/tests/test_augmentations.py
+++ b/tests/test_augmentations.py
@@ -1,18 +1,17 @@
 import fontina.augmentation_utils as au
 import numpy as np
-import PIL
 
 
 def test_resize_fixed_height():
-    test_img = PIL.Image.fromarray(np.ones((70, 100), dtype=np.uint8))
+    test_img = np.ones((70, 100), dtype=np.uint8)
 
     resized = au.resize_fixed_height(test_img, new_height=105)
 
     # Check for the expected height.
-    assert resized.size[1] == 105
+    assert resized.shape[0] == 105
 
     # Check that the ratio of the source image is kept.
     assert np.isclose(
-        test_img.size[0] / test_img.size[1], resized.size[0] / resized.size[1]
+        test_img.shape[0] / test_img.shape[1], resized.shape[0] / resized.shape[1]
     )
 
@@ -58,6 +57,30 @@ def test_pick_random_patch_constrained():
     aug = au.PickRandomPatch(constrained_patches=True, always_apply=True)
     result = aug(image=test_img)["image"]
 
+    assert np.all(np.equal(result.shape, [105, 105]))
+
     # It's either a full black or a full white patch, as we're constraining
     # the random choice into non-random patches.
     assert np.all(result == 255) or np.all(result == 0)
+
+
+def test_pick_random_small_width():
+    # Craft an image with one patch: its width being less than 105.
+    test_img = np.ones((105, 80), dtype=np.uint8) * 255
+
+    aug = au.PickRandomPatch(constrained_patches=True, always_apply=True)
+    result = aug(image=test_img)["image"]
+
+    # The returned patch must still be 105x105.
+    assert np.all(np.equal(result.shape, [105, 105]))
+
+
+def test_pick_random_exact_width():
+    # Craft an image with one patch: its width being exactly 105.
+    test_img = np.ones((105, 105), dtype=np.uint8) * 255
+
+    aug = au.PickRandomPatch(constrained_patches=True, always_apply=True)
+    result = aug(image=test_img)["image"]
+
+    # The returned patch must still be 105x105.
+    assert np.all(np.equal(result.shape, [105, 105]))
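For reference, the padding behaviour these two new tests exercise boils down to the `np.append` call added to `PickRandomPatch.apply`. A standalone sketch (not the class itself); white (255) padding matches the default white background of the synthetic samples, so the added columns do not read as content:

```python
import numpy as np

# A "narrow" grayscale patch: 105 px tall but only 80 px wide.
img = np.full((105, 80), 255, dtype=np.uint8)
width = img.shape[1]

# Right-pad with white columns up to the 105x105 patch size expected by the model.
padded = np.append(img, np.full((105, 105 - width), 255, dtype="uint8"), axis=1)
assert padded.shape == (105, 105)
```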