Migrate AutoAlbument core to PyTorch Lightning (#13)
creafz authored Feb 21, 2021
1 parent 357838d commit 784fb71
Showing 102 changed files with 1,536 additions and 1,858 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/ci.yml
@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies
run: pip install .[tests]
- name: Install linters
run: pip install pydocstyle flake8 flake8-docstrings mypy
run: pip install flake8==3.8.4 flake8-docstrings==1.5.0 mypy==0.800
- name: Run PyTest
run: pytest
- name: Run Flake8
@@ -42,7 +42,7 @@ jobs:
run: mypy autoalbument

check_code_formatting:
name: Check code formatting with Black
name: Check code formatting with isort and Black
runs-on: ubuntu-latest
strategy:
matrix:
@@ -56,7 +56,9 @@
python-version: ${{ matrix.python-version }}
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install Black
run: pip install black==20.8b1
- name: Install isort and Black
run: pip install isort==5.7.0 black==20.8b1
- name: Run isort
run: isort --check-only .
- name: Run Black
run: black --config=black.toml --check .
1 change: 0 additions & 1 deletion .gitignore
@@ -16,7 +16,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
26 changes: 17 additions & 9 deletions .pre-commit-config.yaml
@@ -1,18 +1,26 @@
repos:
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
args: [--config=black.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v3.4.0
hooks:
- id: flake8
- id: trailing-whitespace
- id: check-yaml
- id: end-of-file-fixer
- id: requirements-txt-fixer
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
hooks:
- id: flake8
additional_dependencies: [flake8-docstrings==1.5.0]
- repo: https://github.com/pycqa/isort
rev: 5.7.0
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
args: [--config=black.toml]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.790
rev: v0.800
hooks:
- id: mypy
files: ^autoalbument/
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -2,7 +2,7 @@ include LICENSE
include README.md

graft tests
graft autoalbument/cli/templates
graft autoalbument/cli/resources
graft autoalbument/cli/conf

global-exclude *.py[co] .DS_Store
2 changes: 1 addition & 1 deletion autoalbument/__init__.py
@@ -1 +1 @@
__version__ = "0.3.1"
__version__ = "0.5.0"
@@ -1,14 +1,15 @@
import random

import torch
import torch.nn.functional as F

from autoalbument.faster_autoaugment.albumentations_pytorch.affine import (
get_scaling_matrix,
from autoalbument.albumentations_pytorch.affine import (
get_rotation_matrix,
get_scaling_matrix,
warp_affine,
)
from autoalbument.faster_autoaugment.albumentations_pytorch.utils import MAX_VALUES_BY_DTYPE, TorchPadding, clipped
from autoalbument.albumentations_pytorch.utils import (
MAX_VALUES_BY_DTYPE,
TorchPadding,
clipped,
)


def solarize(img_batch, threshold):
@@ -68,8 +69,14 @@ def cutout(img_batch, num_holes, hole_size, fill_value=0):
img_batch = img_batch.clone()
height, width = img_batch.shape[-2:]
for _n in range(num_holes):
y1 = random.randint(0, height - hole_size)
x1 = random.randint(0, width - hole_size)
if height == hole_size:
y1 = torch.tensor([0])
else:
y1 = torch.randint(0, height - hole_size, (1,))
if width == hole_size:
x1 = torch.tensor([0])
else:
x1 = torch.randint(0, width - hole_size, (1,))
y2 = y1 + hole_size
x2 = x1 + hole_size
img_batch[:, :, y1:y2, x1:x2] = fill_value
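
The special-casing above exists because the commit swaps Python's `random.randint` for `torch.randint`, and the two sample from different ranges: `random.randint(0, n)` is inclusive of `n`, while `torch.randint(0, n, (1,))` samples from the half-open range `[0, n)` and raises an error when `n == 0`. A minimal sketch of the guarded sampling, with hypothetical sizes:

    import torch

    # Hypothetical sizes where the guard matters: a hole that covers the whole
    # image dimension makes `height - hole_size == 0`, and
    # `torch.randint(0, 0, (1,))` raises a RuntimeError.
    height, hole_size = 32, 32
    if height == hole_size:
        y1 = torch.tensor([0])  # the only valid top coordinate
    else:
        y1 = torch.randint(0, height - hole_size, (1,))
    print(y1.item())  # 0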
2 changes: 2 additions & 0 deletions autoalbument/callbacks/__init__.py
@@ -0,0 +1,2 @@
from autoalbument.callbacks.monitor_average_parameter_change import *
from autoalbument.callbacks.save_policy import *
65 changes: 65 additions & 0 deletions autoalbument/callbacks/monitor_average_parameter_change.py
@@ -0,0 +1,65 @@
import copy
import logging

import torch
from pytorch_lightning import Callback

__all__ = [
"MonitorAverageParameterChange",
]


log = logging.getLogger(__name__)


class MonitorAverageParameterChange(Callback):
def __init__(self):
self.policy_model_state_dict_epoch_start = None

@classmethod
def get_combined_parameter_values(cls, state_dict, prefixes):
combined_parameter_values = []
for prefix in prefixes:
weights = state_dict[f"{prefix}._weights"]
temperature = torch.zeros(weights.size(), dtype=weights.dtype).to(weights.device)
for i in range(len(weights)):
temperature_key = f"{prefix}.operations.{i}.temperature"
temperature[i] = state_dict[temperature_key]
weights = torch.div(weights, temperature).softmax(0)
for i, weight in enumerate(weights):
probability_key = f"{prefix}.operations.{i}._probability"
magnitude_key = f"{prefix}.operations.{i}._magnitude"
probability = (weight * state_dict[probability_key].clamp(0.0, 1.0)).item()
magnitude = state_dict.get(magnitude_key)
if magnitude is not None:
magnitude = magnitude.clamp(0.0, 1.0).item()
value = probability * magnitude if magnitude is not None else probability
combined_parameter_values.append(value)
return combined_parameter_values

@classmethod
def get_average_parameter_change(cls, state_dict_1, state_dict_2):
prefixes = {".".join(key.split(".")[:4]) for key in state_dict_1.keys() if key.startswith("sub_policies")}
values_1 = cls.get_combined_parameter_values(state_dict_1, prefixes)
values_2 = cls.get_combined_parameter_values(state_dict_2, prefixes)
values_change = [abs(v2 - v1) for v1, v2 in zip(values_1, values_2)]
return sum(values_change) / len(values_change)

def on_epoch_start(self, trainer, pl_module):
self.policy_model_state_dict_epoch_start = copy.deepcopy(pl_module.policy_model.state_dict())

def on_epoch_end(self, trainer, pl_module):
policy_model_state_dict_epoch_end = pl_module.policy_model.state_dict()
average_parameter_change = self.get_average_parameter_change(
self.policy_model_state_dict_epoch_start,
policy_model_state_dict_epoch_end,
)
pl_module.log(
"average_parameter_change",
average_parameter_change,
on_step=False,
on_epoch=True,
prog_bar=False,
logger=True,
)
log.info(f"Average Parameter change at epoch {trainer.current_epoch}: {average_parameter_change:.6f}.")
36 changes: 36 additions & 0 deletions autoalbument/callbacks/save_policy.py
@@ -0,0 +1,36 @@
import logging
import os
import shutil

import albumentations as A
from pytorch_lightning import Callback

__all__ = [
"SavePolicy",
]


log = logging.getLogger(__name__)


class SavePolicy(Callback):
def __init__(self, dirpath=None, latest_policy_filename="latest.json"):
self.dirpath = dirpath or os.path.join(os.getcwd(), "policy")
self.latest_policy_filepath = os.path.join(self.dirpath, latest_policy_filename)
os.makedirs(self.dirpath, exist_ok=True)

def on_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
datamodule = trainer.datamodule
cfg = pl_module.cfg
transform = pl_module.policy_model.create_transform(
input_dtype=cfg.data.input_dtype,
preprocessing_transforms=datamodule.get_preprocessing_transforms(),
)
policy_file_filepath = os.path.join(self.dirpath, f"epoch_{epoch}.json")
A.save(transform, policy_file_filepath)
shutil.copy2(policy_file_filepath, self.latest_policy_filepath)
log.info(
f"Policy is saved to {policy_file_filepath}. "
f"{self.latest_policy_filepath} now also contains this policy."
)
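
Policies written by this callback are plain Albumentations serializations, so they can be loaded back with `A.load`. A minimal sketch, assuming the default `dirpath` and `latest_policy_filename` used above:

    import albumentations as A
    import numpy as np

    # Load the most recent policy saved by SavePolicy (default location) and
    # apply it to an image; the uint8 image matches `input_dtype: uint8`.
    transform = A.load("policy/latest.json")
    image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    augmented = transform(image=image)["image"]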
5 changes: 5 additions & 0 deletions autoalbument/cli/conf/_version.yaml
@@ -0,0 +1,5 @@
# @package _global_

_version: 2 # An internal value that indicates a version of the config schema. This value is used by
# `autoalbument-search` and `autoalbument-migrate` to upgrade the config to the latest version if necessary.
# Please do not change it manually.
19 changes: 19 additions & 0 deletions autoalbument/cli/conf/callbacks/default.yaml
@@ -0,0 +1,19 @@
# @package _group_

# A list of PyTorch Lightning callbacks. Documentation on callbacks is available at
# https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html

- _target_: autoalbument.callbacks.MonitorAverageParameterChange
# Prints the "Average Parameter Change" metric at the end of each epoch.
# Read more about this metric at https://albumentations.ai/docs/autoalbument/metrics/#average-parameter-change

- _target_: autoalbument.callbacks.SavePolicy
# Saves augmentation policies at the end of each epoch. You can load saved policies with Albumentations to create
# an augmentation pipeline.

- _target_: pytorch_lightning.callbacks.ModelCheckpoint
save_last: true
dirpath: checkpoints
# Saves a checkpoint at the end of each epoch. The checkpoint will contain all the necessary data to resume training.
# More information about this checkpoint -
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html
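
Each entry in this list is a Hydra `_target_` spec, so the search pipeline can turn the whole group into callback objects with `hydra.utils.instantiate`. A sketch of that mechanism (AutoAlbument's actual wiring may differ):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Build the same structure the YAML above describes and instantiate it.
    cfg = OmegaConf.create(
        [
            {"_target_": "autoalbument.callbacks.MonitorAverageParameterChange"},
            {"_target_": "autoalbument.callbacks.SavePolicy"},
            {
                "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
                "save_last": True,
                "dirpath": "checkpoints",
            },
        ]
    )
    callbacks = [instantiate(c) for c in cfg]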
22 changes: 22 additions & 0 deletions autoalbument/cli/conf/classification_model/default.yaml
@@ -0,0 +1,22 @@
# @package _group_

# Settings for Classification Model that is used for two purposes:
# 1. As a model that performs classification of input images.
# 2. As a Discriminator for Policy Model.

_target_: autoalbument.faster_autoaugment.models.ClassificationModel
# Python class for instantiating Classification Model. You can read more about overriding this value
# to use a custom model at https://albumentations.ai/docs/autoalbument/custom_model/

num_classes: _MISSING_
# Number of classes in the dataset. The dataset implementation should return an integer in the range
# [0, num_classes - 1] as a class label of an image.

architecture: resnet18
# Architecture of Classification Model. The default implementation of Classification model in AutoAlbument uses
# models from https://github.com/rwightman/pytorch-image-models/. Please refer to its documentation to get a list of
# available models - https://rwightman.github.io/pytorch-image-models/#list-models-with-pretrained-weights.

pretrained: False
# Boolean flag that indicates whether the selected model architecture should load pretrained weights or use randomly
# initialized weights.
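
Since the default implementation draws architectures from timm, the three values above roughly correspond to a `timm.create_model` call. A hedged sketch (the actual `ClassificationModel` wrapper also serves as the Policy Model's discriminator, so it likely adds more than this):

    import timm

    # Hypothetical num_classes value; `architecture` and `pretrained` map
    # directly onto timm's API.
    model = timm.create_model("resnet18", pretrained=False, num_classes=10)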
19 changes: 10 additions & 9 deletions autoalbument/cli/conf/config.yaml
@@ -1,14 +1,15 @@
defaults:
# Settings for Policy Model that performs search for augmentation policies.
- _version
- task
- policy_model: default

# Settings for Semantic Segmentation Model that is used for two purposes:
# 1. As a model that performs semantic segmentation of input images.
# 2. As a Discriminator for Policy Model.
- classification_model: default
- semantic_segmentation_model: default

- optim: default
- data: default

- default
- searcher: default
- trainer: default
- optim: default
- callbacks: default
- logger: default
- hydra: default
- seed
- search
21 changes: 13 additions & 8 deletions autoalbument/cli/conf/data/default.yaml
@@ -1,12 +1,17 @@
# @package _group_

dataset:
_target_: dataset.SearchDataset
# Class for instantiating a PyTorch dataset.

input_dtype: uint8
# The data type of input images. Two values are supported:
# - uint8. In that case, all input images should be NumPy arrays with the np.uint8 data type and values in the range
# [0, 255].
# - float32. In that case, all input images should be NumPy arrays with the np.float32 data type and values in the
# range [0.0, 1.0].
input_dtype: uint8

preprocessing: null
# A list of preprocessing augmentations that will be applied to each image before applying augmentations from
# a policy. A preprocessing augmentation should be defined as `key`: `value`, where `key` is the name of augmentation
# from Albumentations, and `value` is a dictionary with augmentation parameters. The found policy will also apply
@@ -26,22 +31,22 @@ input_dtype: uint8
# height: 224
# width: 224
#
preprocessing: null

normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
# Normalization values for images. For each image, the search pipeline will subtract `mean` and divide by `std`.
# Normalization is applied after transforms defined in `preprocessing`. Note that regardless of `input_dtype`,
# the normalization function will always receive a `float32` input with values in the range [0.0, 1.0], so you should
# define `mean` and `std` values accordingly. ImageNet normalization is used by default.
normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

# Parameters for the PyTorch DataLoader. Please refer to the PyTorch documentation for the description of parameters -
# https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.

dataloader:
_target_: torch.utils.data.DataLoader
batch_size: 128
batch_size: 16
shuffle: True
num_workers: 8
pin_memory: True
drop_last: True
# Parameters for the PyTorch DataLoader. Please refer to the PyTorch documentation for the description of parameters -
# https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.
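
The `_target_: dataset.SearchDataset` entry points at a user-supplied dataset module. A minimal sketch of what such a class might look like for the classification task, using random placeholder data and matching `input_dtype: uint8`:

    import numpy as np
    import torch.utils.data

    class SearchDataset(torch.utils.data.Dataset):
        def __init__(self, transform=None):
            self.transform = transform
            # Placeholder data; a real implementation would load actual
            # images (uint8, HWC, values in [0, 255]) and integer labels.
            self.images = np.random.randint(0, 256, (100, 128, 128, 3), dtype=np.uint8)
            self.labels = np.random.randint(0, 10, 100)

        def __len__(self):
            return len(self.images)

        def __getitem__(self, index):
            image, label = self.images[index], self.labels[index]
            if self.transform is not None:
                image = self.transform(image=image)["image"]
            return image, label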
21 changes: 0 additions & 21 deletions autoalbument/cli/conf/default.yaml

This file was deleted.

7 changes: 7 additions & 0 deletions autoalbument/cli/conf/hydra/default.yaml
@@ -0,0 +1,7 @@
# @package _group_

run:
dir: ${config_dir:}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
# Path to the directory that will contain all outputs produced by the search algorithm. `${config_dir:}` contains
# path to the directory with the `search.yaml` config file. Please refer to the Hydra documentation for more
# information - https://hydra.cc/docs/configure_hydra/workdir.
8 changes: 8 additions & 0 deletions autoalbument/cli/conf/logger/default.yaml
@@ -0,0 +1,8 @@
# @package _group_

# Configuration for a PyTorch Lightning logger.
# You can read more about loggers at https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html
# By default, TensorBoardLogger is used.

_target_: pytorch_lightning.loggers.TensorBoardLogger
save_dir: ${config_dir:}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}/tensorboard_logs
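
At runtime this config resolves to an ordinary Lightning logger instance; a sketch with a literal path, since `${config_dir:}` is a custom resolver filled in by AutoAlbument:

    from pytorch_lightning.loggers import TensorBoardLogger

    # Equivalent of the YAML above with the resolvers already expanded
    # (the timestamped path shown is purely illustrative).
    logger = TensorBoardLogger(save_dir="outputs/2021-02-21/10-00-00/tensorboard_logs")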