Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework baselines #25

Merged
merged 23 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
953cc97
Update Classification routine to be no more abstract :hammer:
alafage May 27, 2023
2c38c67
Add unique baseline for ResNet :hammer:
alafage May 27, 2023
663e0b4
Add unique experiment file for all ResNet on CIFAR10 :sparkles:
alafage May 27, 2023
5552cda
Update WideResNet baseline :hammer:
alafage May 29, 2023
406adfc
Update experiments :hammer:
alafage May 29, 2023
662366c
Update tests accordingly to changes :hammer:
alafage May 29, 2023
1d1db24
Update API reference :books:
alafage May 29, 2023
a7c8021
Add support for BatchEnsemble :sparkles:
o-laurent May 30, 2023
c80d205
Add BastchEnsembles & TempScaling to Rdme :book:
o-laurent May 30, 2023
f969be2
Add wideresnet experiments :sparkles:
o-laurent May 30, 2023
981d7ea
Fix BatchEnsembles optimizer :bug:
o-laurent May 30, 2023
3dfbe5c
Fix experiment name in cifar100 :bug:
o-laurent May 30, 2023
a9d3c26
Revert del. of PL override of None num_epochs => 1k epochs :hammer:
o-laurent May 30, 2023
e9ae777
Use get_procedure in opt. proc. tests :heavy_check_mark:
o-laurent May 30, 2023
2147bca
Second CLI test with different arguments :heavy_check_mark:
o-laurent May 30, 2023
5236c70
Factorize OOD criterion arguments :hammer:
o-laurent May 31, 2023
2b3c16a
Fix num_estimator duplicate :bug:
o-laurent May 31, 2023
cd06cd1
Add groups to all networks :sparkles:
o-laurent May 31, 2023
25823ae
Polish baselines and layer argument checks :hammer:
alafage Jun 2, 2023
bbee3bd
Add docstrings to baselines :bulb:
alafage Jun 2, 2023
c0c4481
Simplify parser arguments for baselines :hammer:
alafage Jun 2, 2023
945cbe5
Solve review comments :ok_hand:
o-laurent Jun 2, 2023
e93ce38
Add forgotten consistency check :ok_hand:
o-laurent Jun 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ _TorchUncertainty_ is a package designed to help you leverage uncertainty quanti
---

This package provides a multi-level API, including:

- ready-to-train baselines on research datasets, such as ImageNet and CIFAR
- baselines available for training on your datasets
- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (work in progress 🚧).
Expand All @@ -38,16 +39,24 @@ Please find the documentation at [torch-uncertainty.github.io](https://torch-unc

A quickstart is available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html).

## Implemented baselines
## Implemented methods

### Baselines

To date, the following baselines are implemented:

- Deep Ensembles
- BatchEnsemble
- Masksembles
- Packed-Ensembles

## Tutorials
### Post-processing methods

To date, the following post-processing methods are implemented:

- Temperature scaling

## Tutorials

## Awesome Uncertainty repositories

Expand All @@ -58,10 +67,12 @@ You may find a lot of information about modern uncertainty estimation techniques
This package also contains the official implementation of Packed-Ensembles.

If you find the corresponding models interesting, please consider citing our [paper](https://arxiv.org/abs/2210.09184):

@inproceedings{laurent2023packed,
title={Packed-Ensembles for Efficient Uncertainty Estimation},
author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni},
booktitle={ICLR},
year={2023}
}

```text
@inproceedings{laurent2023packed,
title={Packed-Ensembles for Efficient Uncertainty Estimation},
author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni},
booktitle={ICLR},
year={2023}
}
```
38 changes: 2 additions & 36 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ This API provides lightning-based models that can be easily trained and evaluate

.. currentmodule:: torch_uncertainty.baselines

Vanilla
^^^^^^^
Classification
^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
Expand All @@ -21,40 +21,6 @@ Vanilla
ResNet
WideResNet

Packed-Ensembles
^^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

PackedResNet
PackedWideResNet

Masksembles
^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

MaskedResNet
MaskedWideResNet

BatchEnsemble
^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst


BatchedResNet
BatchedWideResNet

Models
------

Expand Down
23 changes: 0 additions & 23 deletions experiments/batched/resnet18.py

This file was deleted.

35 changes: 35 additions & 0 deletions experiments/classification/cifar10/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import ResNet
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(ResNet, CIFAR10DataModule)

net_name = f"{args.version}-resnet{args.arch}-cifar10"

# datamodule
args.root = str(root / "data")
dm = CIFAR10DataModule(**vars(args))

# model
model = ResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar10", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
35 changes: 35 additions & 0 deletions experiments/classification/cifar10/wideresnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import WideResNet
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(WideResNet, CIFAR10DataModule)

net_name = f"{args.version}-wideresnet{args.arch}-cifar10"

# datamodule
args.root = str(root / "data")
dm = CIFAR10DataModule(**vars(args))

# model
model = WideResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar10", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
35 changes: 35 additions & 0 deletions experiments/classification/cifar100/resnet.py
o-laurent marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import ResNet
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(ResNet, CIFAR100DataModule)

net_name = f"{args.version}-resnet{args.arch}-cifar100"

# datamodule
args.root = str(root / "data")
dm = CIFAR100DataModule(**vars(args))

# model
model = ResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar100", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
35 changes: 35 additions & 0 deletions experiments/classification/cifar100/wideresnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import WideResNet
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(WideResNet, CIFAR100DataModule)

net_name = f"{args.version}-wideresnet{args.arch}-cifar10"

# datamodule
args.root = str(root / "data")
dm = CIFAR100DataModule(**vars(args))

# model
model = WideResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar100", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
15 changes: 15 additions & 0 deletions experiments/classification/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Classification Benchmarks
o-laurent marked this conversation as resolved.
Show resolved Hide resolved

*Work in progress*

## Image Classification

### CIFAR-10

* ResNet
* WideResNet

### CIFAR-100

* ResNet
* WideResNet
67 changes: 0 additions & 67 deletions experiments/experiments.py

This file was deleted.

23 changes: 0 additions & 23 deletions experiments/masked/resnet18.py

This file was deleted.

Loading