-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Migrate AutoAlbument core to PyTorch Lightning (#13)
- Loading branch information
Showing
102 changed files
with
1,536 additions
and
1,858 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,6 @@ dist/ | |
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,26 @@ | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 20.8b1 | ||
hooks: | ||
- id: black | ||
args: [--config=black.toml] | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v2.3.0 | ||
rev: v3.4.0 | ||
hooks: | ||
- id: flake8 | ||
- id: trailing-whitespace | ||
- id: check-yaml | ||
- id: end-of-file-fixer | ||
- id: requirements-txt-fixer | ||
- repo: https://gitlab.com/pycqa/flake8 | ||
rev: 3.8.4 | ||
hooks: | ||
- id: flake8 | ||
additional_dependencies: [flake8-docstrings==1.5.0] | ||
- repo: https://github.com/pycqa/isort | ||
rev: 5.7.0 | ||
hooks: | ||
- id: isort | ||
- repo: https://github.com/psf/black | ||
rev: 20.8b1 | ||
hooks: | ||
- id: black | ||
args: [--config=black.toml] | ||
- repo: https://github.com/pre-commit/mirrors-mypy | ||
rev: v0.790 | ||
rev: v0.800 | ||
hooks: | ||
- id: mypy | ||
files: ^autoalbument/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.3.1" | ||
__version__ = "0.5.0" |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from autoalbument.callbacks.monitor_average_parameter_change import * | ||
from autoalbument.callbacks.save_policy import * |
65 changes: 65 additions & 0 deletions
65
autoalbument/callbacks/monitor_average_parameter_change.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import copy | ||
import logging | ||
|
||
import torch | ||
from pytorch_lightning import Callback | ||
|
||
__all__ = [ | ||
"MonitorAverageParameterChange", | ||
] | ||
|
||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class MonitorAverageParameterChange(Callback): | ||
def __init__(self): | ||
self.policy_model_state_dict_epoch_start = None | ||
|
||
@classmethod | ||
def get_combined_parameter_values(cls, state_dict, prefixes): | ||
combined_parameter_values = [] | ||
for prefix in prefixes: | ||
weights = state_dict[f"{prefix}._weights"] | ||
temperature = torch.zeros(weights.size(), dtype=weights.dtype).to(weights.device) | ||
for i in range(len(weights)): | ||
temperature_key = f"{prefix}.operations.{i}.temperature" | ||
temperature[i] = state_dict[temperature_key] | ||
weights = torch.div(weights, temperature).softmax(0) | ||
for i, weight in enumerate(weights): | ||
probability_key = f"{prefix}.operations.{i}._probability" | ||
magnitude_key = f"{prefix}.operations.{i}._magnitude" | ||
probability = (weight * state_dict[probability_key].clamp(0.0, 1.0)).item() | ||
magnitude = state_dict.get(magnitude_key) | ||
if magnitude is not None: | ||
magnitude = magnitude.clamp(0.0, 1.0).item() | ||
value = probability * magnitude if magnitude is not None else probability | ||
combined_parameter_values.append(value) | ||
return combined_parameter_values | ||
|
||
@classmethod | ||
def get_average_parameter_change(cls, state_dict_1, state_dict_2): | ||
prefixes = {".".join(key.split(".")[:4]) for key in state_dict_1.keys() if key.startswith("sub_policies")} | ||
values_1 = cls.get_combined_parameter_values(state_dict_1, prefixes) | ||
values_2 = cls.get_combined_parameter_values(state_dict_2, prefixes) | ||
values_change = [abs(v2 - v1) for v1, v2 in zip(values_1, values_2)] | ||
return sum(values_change) / len(values_change) | ||
|
||
def on_epoch_start(self, trainer, pl_module): | ||
self.policy_model_state_dict_epoch_start = copy.deepcopy(pl_module.policy_model.state_dict()) | ||
|
||
def on_epoch_end(self, trainer, pl_module): | ||
policy_model_state_dict_epoch_end = pl_module.policy_model.state_dict() | ||
average_parameter_change = self.get_average_parameter_change( | ||
self.policy_model_state_dict_epoch_start, | ||
policy_model_state_dict_epoch_end, | ||
) | ||
pl_module.log( | ||
"average_parameter_change", | ||
average_parameter_change, | ||
on_step=False, | ||
on_epoch=True, | ||
prog_bar=False, | ||
logger=True, | ||
) | ||
log.info(f"Average Parameter change at epoch {trainer.current_epoch}: {average_parameter_change:.6f}.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import logging | ||
import os | ||
import shutil | ||
|
||
import albumentations as A | ||
from pytorch_lightning import Callback | ||
|
||
__all__ = [ | ||
"SavePolicy", | ||
] | ||
|
||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class SavePolicy(Callback): | ||
def __init__(self, dirpath=None, latest_policy_filename="latest.json"): | ||
self.dirpath = dirpath or os.path.join(os.getcwd(), "policy") | ||
self.latest_policy_filepath = os.path.join(self.dirpath, latest_policy_filename) | ||
os.makedirs(self.dirpath, exist_ok=True) | ||
|
||
def on_epoch_end(self, trainer, pl_module): | ||
epoch = trainer.current_epoch | ||
datamodule = trainer.datamodule | ||
cfg = pl_module.cfg | ||
transform = pl_module.policy_model.create_transform( | ||
input_dtype=cfg.data.input_dtype, | ||
preprocessing_transforms=datamodule.get_preprocessing_transforms(), | ||
) | ||
policy_file_filepath = os.path.join(self.dirpath, f"epoch_{epoch}.json") | ||
A.save(transform, policy_file_filepath) | ||
shutil.copy2(policy_file_filepath, self.latest_policy_filepath) | ||
log.info( | ||
f"Policy is saved to {policy_file_filepath}. " | ||
f"{self.latest_policy_filepath} now also contains this policy." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# @package _global_ | ||
|
||
_version: 2 # An internal value that indicates a version of the config schema. This value is used by | ||
# `autoalbument-search` and `autoalbument-migrate` to upgrade the config to the latest version if necessary. | ||
# Please do not change it manually. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# @package _group_ | ||
|
||
# A list of PyTorch Lightning callbacks. Documentation on callbacks is available at | ||
# https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html | ||
|
||
- _target_: autoalbument.callbacks.MonitorAverageParameterChange | ||
# Prints the "Average Parameter Change" metric at the end of each epoch. | ||
# Read more about this metric at https://albumentations.ai/docs/autoalbument/metrics/#average-parameter-change | ||
|
||
- _target_: autoalbument.callbacks.SavePolicy | ||
# Saves augmentation policies at the end of each epoch. You can load saved policies with Albumentations to create | ||
# an augmentation pipeline. | ||
|
||
- _target_: pytorch_lightning.callbacks.ModelCheckpoint | ||
save_last: true | ||
dirpath: checkpoints | ||
# Saves a checkpoint at the end of each epoch. The checkpoint will contain all the necessary data to resume training. | ||
# More information about this checkpoint - | ||
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# @package _group_ | ||
|
||
# Settings for Classification Model that is used for two purposes: | ||
# 1. As a model that performs classification of input images. | ||
# 2. As a Discriminator for Policy Model. | ||
|
||
_target_: autoalbument.faster_autoaugment.models.ClassificationModel | ||
# Python class for instantiating Classification Model. You can read more about overriding this value | ||
# to use a custom model at https://albumentations.ai/docs/autoalbument/custom_model/ | ||
|
||
num_classes: _MISSING_ | ||
# Number of classes in the dataset. The dataset implementation should return an integer in the range | ||
# [0, num_classes - 1] as a class label of an image. | ||
|
||
architecture: resnet18 | ||
# Architecture of Classification Model. The default implementation of Classification model in AutoAlbument uses | ||
# models from https://github.com/rwightman/pytorch-image-models/. Please refer to its documentation to get a list of | ||
# available models - https://rwightman.github.io/pytorch-image-models/#list-models-with-pretrained-weights. | ||
|
||
pretrained: False | ||
# Boolean flag that indicates whether the selected model architecture should load pretrained weights or use randomly | ||
# initialized weights. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,15 @@ | ||
defaults: | ||
# Settings for Policy Model that performs search for augmentation policies. | ||
- _version | ||
- task | ||
- policy_model: default | ||
|
||
# Settings for Semantic Segmentation Model that is used for two purposes: | ||
# 1. As a model that performs semantic segmentation of input images. | ||
# 2. As a Discriminator for Policy Model. | ||
- classification_model: default | ||
- semantic_segmentation_model: default | ||
|
||
- optim: default | ||
- data: default | ||
|
||
- default | ||
- searcher: default | ||
- trainer: default | ||
- optim: default | ||
- callbacks: default | ||
- logger: default | ||
- hydra: default | ||
- seed | ||
- search |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# @package _group_ | ||
|
||
run: | ||
dir: ${config_dir:}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} | ||
# Path to the directory that will contain all outputs produced by the search algorithm. `${config_dir:}` contains | ||
# path to the directory with the `search.yaml` config file. Please refer to the Hydra documentation for more | ||
# information - https://hydra.cc/docs/configure_hydra/workdir. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# @package _group_ | ||
|
||
# Configuration for a PyTorch Lightning logger. | ||
# You can read more about loggers at https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html | ||
# By default, TensorBoardLogger is used. | ||
|
||
_target_: pytorch_lightning.loggers.TensorBoardLogger | ||
save_dir: ${config_dir:}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}/tensorboard_logs |
Oops, something went wrong.