Skip to content

Commit

Permalink
Merge pull request #25 from ENSTA-U2IS/rework-baselines
Browse files Browse the repository at this point in the history
Rework baselines
  • Loading branch information
alafage authored Jun 2, 2023
2 parents a99d957 + e93ce38 commit 94461c1
Show file tree
Hide file tree
Showing 56 changed files with 1,182 additions and 1,714 deletions.
29 changes: 20 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ _TorchUncertainty_ is a package designed to help you leverage uncertainty quanti
---

This package provides a multi-level API, including:

- ready-to-train baselines on research datasets, such as ImageNet and CIFAR
- baselines available for training on your datasets
- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (work in progress 🚧).
Expand All @@ -38,16 +39,24 @@ Please find the documentation at [torch-uncertainty.github.io](https://torch-unc

A quickstart is available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html).

## Implemented baselines
## Implemented methods

### Baselines

To date, the following baselines are implemented:

- Deep Ensembles
- BatchEnsemble
- Masksembles
- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873))

## Tutorials
### Post-processing methods

To date, the following post-processing methods are implemented:

- Temperature scaling

## Tutorials

## Awesome Uncertainty repositories

Expand All @@ -58,10 +67,12 @@ You may find a lot of information about modern uncertainty estimation techniques
This package also contains the official implementation of Packed-Ensembles.

If you find the corresponding models interesting, please consider citing our [paper](https://arxiv.org/abs/2210.09184):

@inproceedings{laurent2023packed,
title={Packed-Ensembles for Efficient Uncertainty Estimation},
author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni},
booktitle={ICLR},
year={2023}
}

```text
@inproceedings{laurent2023packed,
title={Packed-Ensembles for Efficient Uncertainty Estimation},
author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni},
booktitle={ICLR},
year={2023}
}
```
38 changes: 2 additions & 36 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ This API provides lightning-based models that can be easily trained and evaluate

.. currentmodule:: torch_uncertainty.baselines

Vanilla
^^^^^^^
Classification
^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
Expand All @@ -21,40 +21,6 @@ Vanilla
ResNet
WideResNet

Packed-Ensembles
^^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

PackedResNet
PackedWideResNet

Masksembles
^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

MaskedResNet
MaskedWideResNet

BatchEnsemble
^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst


BatchedResNet
BatchedWideResNet

Models
------

Expand Down
23 changes: 0 additions & 23 deletions experiments/batched/resnet18.py

This file was deleted.

35 changes: 35 additions & 0 deletions experiments/classification/cifar10/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import ResNet
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(ResNet, CIFAR10DataModule)

net_name = f"{args.version}-resnet{args.arch}-cifar10"

# datamodule
args.root = str(root / "data")
dm = CIFAR10DataModule(**vars(args))

# model
model = ResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar10", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
35 changes: 35 additions & 0 deletions experiments/classification/cifar10/wideresnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import WideResNet
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(WideResNet, CIFAR10DataModule)

net_name = f"{args.version}-wideresnet{args.arch}-cifar10"

# datamodule
args.root = str(root / "data")
dm = CIFAR10DataModule(**vars(args))

# model
model = WideResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar10", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
35 changes: 35 additions & 0 deletions experiments/classification/cifar100/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import ResNet
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(ResNet, CIFAR100DataModule)

net_name = f"{args.version}-resnet{args.arch}-cifar100"

# datamodule
args.root = str(root / "data")
dm = CIFAR100DataModule(**vars(args))

# model
model = ResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar100", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
35 changes: 35 additions & 0 deletions experiments/classification/cifar100/wideresnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# fmt: off
from pathlib import Path

import torch.nn as nn

from torch_uncertainty import cls_main, init_args
from torch_uncertainty.baselines import WideResNet
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.optimization_procedures import get_procedure

# fmt: on
if __name__ == "__main__":
root = Path(__file__).parent.absolute().parents[2]

args = init_args(WideResNet, CIFAR100DataModule)

net_name = f"{args.version}-wideresnet{args.arch}-cifar10"

# datamodule
args.root = str(root / "data")
dm = CIFAR100DataModule(**vars(args))

# model
model = WideResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"resnet{args.arch}", "cifar100", args.version
),
imagenet_structure=False,
**vars(args),
)

cls_main(model, dm, root, net_name, args)
15 changes: 15 additions & 0 deletions experiments/classification/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Classification Benchmarks

*Work in progress*

## Image Classification

### CIFAR-10

* ResNet
* WideResNet

### CIFAR-100

* ResNet
* WideResNet
67 changes: 0 additions & 67 deletions experiments/experiments.py

This file was deleted.

23 changes: 0 additions & 23 deletions experiments/masked/resnet18.py

This file was deleted.

Loading

0 comments on commit 94461c1

Please sign in to comment.