From 870af02d610d5e18c92d550c95f9df5e3f2f68f8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 29 May 2024 18:04:38 +0200 Subject: [PATCH 01/57] :sparkles: Add LeNet experiment on MNIST --- .../classification/mnist/configs/lenet.yaml | 57 +++++++++++++++++ experiments/classification/mnist/lenet.py | 61 ++++++------------- 2 files changed, 75 insertions(+), 43 deletions(-) create mode 100644 experiments/classification/mnist/configs/lenet.yaml diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml new file mode 100644 index 00000000..c4635dfe --- /dev/null +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -0,0 +1,57 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/lenet.py b/experiments/classification/mnist/lenet.py index 450f72c2..ffb2b8f2 100644 --- a/experiments/classification/mnist/lenet.py +++ b/experiments/classification/mnist/lenet.py @@ -1,52 +1,27 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn, optim - -from torch_uncertainty import cli_main, init_args from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.routines.classification import ClassificationSingle +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TULightningCLI -def optim_lenet(model: nn.Module) -> dict: - """Optimization recipe for LeNet. +class MNISTCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) - Uses Adam default hyperparameters. - Args: - model (nn.Module): LeNet model. - """ - return { - "optimizer": optim.Adam( - model.parameters(), - ) - } +def cli_main() -> MNISTCLI: + return MNISTCLI(ClassificationRoutine, MNISTDataModule) if __name__ == "__main__": - args = init_args(datamodule=MNISTDataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - - if args.exp_name == "": - args.exp_name = "std-lenet-mnist" - - # datamodule - args.root = str(root / "data") - dm = MNISTDataModule(**vars(args)) - - # model - model = lenet(dm.num_channels, dm.num_classes) - - baseline = ClassificationSingle( - model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss(), - optim_recipe=optim_lenet, - **vars(args), - ) - - cli_main(baseline, dm, args.exp_dir, args.exp_name, args) + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") From ec61536ecc21ce9ec5c4c87de6a665db5d2823c6 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 30 May 2024 15:55:17 +0200 Subject: [PATCH 02/57] :bug: Fix notMNIST --- torch_uncertainty/datasets/classification/not_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index 71d2bc6a..e15ac640 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -66,7 +66,7 @@ def __init__( ) super().__init__( - self.root + f"/notMNIST_{subset}", + self.root / ("notMNIST_" + subset), transform=transform, target_transform=target_transform, ) @@ -97,4 +97,4 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: Args: index (int): The index of the sample to get. """ - return super().__getitem__(index)[0] + return super().__getitem__(index) From d2689114f724fcf245b3e9c9a1fb0667c090a389 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 31 May 2024 11:46:32 +0200 Subject: [PATCH 03/57] :bug: Fix MNIST datamodule OODs --- .../classification/test_mnist_datamodule.py | 7 ++++++- .../datamodules/classification/mnist.py | 21 +++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/datamodules/classification/test_mnist_datamodule.py b/tests/datamodules/classification/test_mnist_datamodule.py index 1707409a..d0517415 100644 --- a/tests/datamodules/classification/test_mnist_datamodule.py +++ b/tests/datamodules/classification/test_mnist_datamodule.py @@ -19,7 +19,11 @@ def test_mnist_cutout(self): assert isinstance(dm.train_transform.transforms[0], Cutout) dm = MNISTDataModule( - root="./data/", batch_size=128, ood_ds="not", cutout=0, val_split=0 + root="./data/", + batch_size=128, + ood_ds="notMNIST", + cutout=0, + val_split=0, ) assert isinstance(dm.train_transform.transforms[0], nn.Identity) @@ -42,6 +46,7 @@ def test_mnist_cutout(self): dm.setup("other") dm.eval_ood = True + dm.ood_transform = dm.test_transform dm.val_split = 0.1 dm.prepare_data() dm.setup() diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 77a6f4f5..b18b279d 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -17,14 +17,14 @@ class MNISTDataModule(AbstractDataModule): num_channels = 1 input_shape = (1, 28, 28) training_task = "classification" - ood_datasets = ["fashion", "not"] + ood_datasets = ["fashion", "notMNIST"] def __init__( self, root: str | Path, batch_size: int, eval_ood: bool = False, - ood_ds: Literal["fashion", "not"] = "fashion", + ood_ds: Literal["fashion", "notMNIST"] = "fashion", val_split: float | None = None, num_workers: int = 1, cutout: int | None = None, @@ -39,7 +39,7 @@ def __init__( eval_ood (bool): Whether to evaluate on out-of-distribution data. batch_size (int): Number of samples per batch. ood_ds (str): Which out-of-distribution dataset to use. Defaults to - ``"fashion"``; `fashion` stands for FashionMNIST and `not` for + ``"fashion"``; `fashion` stands for FashionMNIST and `notMNIST` for notMNIST. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. @@ -71,11 +71,11 @@ def __init__( if ood_ds == "fashion": self.ood_dataset = FashionMNIST - elif ood_ds == "not": + elif ood_ds == "notMNIST": self.ood_dataset = NotMNIST else: raise ValueError( - f"`ood_ds` should be `fashion` or `not`. Got {ood_ds}." + f"`ood_ds` should be in {self.ood_datasets}. Got {ood_ds}." ) main_transform = Cutout(cutout) if cutout else nn.Identity() @@ -95,6 +95,15 @@ def __init__( T.Normalize((0.1307,), (0.3081,)), ] ) + if self.eval_ood: # NotMNIST has 3 channels + self.ood_transform = T.Compose( + [ + T.Grayscale(num_output_channels=1), + T.ToTensor(), + T.CenterCrop(28), + T.Normalize((0.1307,), (0.3081,)), + ] + ) def prepare_data(self) -> None: # coverage: ignore """Download the datasets.""" @@ -140,7 +149,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: self.ood = self.ood_dataset( self.root, download=False, - transform=self.test_transform, + transform=self.ood_transform, ) def test_dataloader(self) -> list[DataLoader]: From 19fafbdce49ae404f85a35a88818f513da0c8ade Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 2 Jun 2024 16:42:55 +0200 Subject: [PATCH 04/57] :sparkles: Add Laplace wrapper --- docs/source/api.rst | 11 ++- pyproject.toml | 7 +- torch_uncertainty/post_processing/laplace.py | 68 +++++++++++++++++++ .../post_processing/mc_batch_norm.py | 2 +- 4 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 torch_uncertainty/post_processing/laplace.py diff --git a/docs/source/api.rst b/docs/source/api.rst index d4f99acf..a3762b92 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -242,6 +242,16 @@ Post-Processing Methods .. currentmodule:: torch_uncertainty.post_processing +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class_inherited.rst + MCBatchNorm + Laplace + +Scaling Methods +^^^^^^^^^^^^^^^ + .. autosummary:: :toctree: generated/ :nosignatures: @@ -250,7 +260,6 @@ Post-Processing Methods TemperatureScaler VectorScaler MatrixScaler - MCBatchNorm Datamodules ----------- diff --git a/pyproject.toml b/pyproject.toml index 0b11a230..822924d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ ] [project.optional-dependencies] -image = ["scikit-image", "h5py",] +image = ["scikit-image", "h5py", "webdataset"] tabular = ["pandas"] dev = [ "torch_uncertainty[image]", @@ -63,7 +63,10 @@ docs = [ "sphinx-design", "sphinx-codeautolink", ] -all = ["torch_uncertainty[dev,docs,image,tabular]"] +all = [ + "torch_uncertainty[dev,docs,image,tabular]", + "laplace-torch" + ] [project.urls] homepage = "https://torch-uncertainty.github.io/" diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py new file mode 100644 index 00000000..e3f19d13 --- /dev/null +++ b/torch_uncertainty/post_processing/laplace.py @@ -0,0 +1,68 @@ +from importlib import util +from typing import Literal + +from torch import Tensor, nn +from torch.utils.data import Dataset + +if util.find_spec("laplace"): + from laplace import Laplace + + laplace_installed = True + + +class Laplace(nn.Module): + def __init__( + self, + model: nn.Module, + task: Literal["classification", "regression"], + subset_of_weights="last_layer", + hessian_structure="kron", + pred_type: Literal["glm", "nn"] = "glm", + link_approx: Literal[ + "mc", "probit", "bridge", "bridge_norm" + ] = "probit", + ) -> None: + """Laplace approximation for uncertainty estimation. + + This class is a wrapper of Laplace classes from the laplace-torch library. + + Args: + model (nn.Module): model to be converted. + task (Literal["classification", "regression"]): task type. + subset_of_weights (str): subset of weights to be considered. Defaults to + "last_layer". + hessian_structure (str): structure of the Hessian matrix. Defaults to + "kron". + pred_type (Literal["glm", "nn"], optional): type of posterior predictive, + See the Laplace library for more details. Defaults to "glm". + link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional): + how to approximate the classification link function for the `'glm'`. + See the Laplace library for more details. Defaults to "probit". + + Reference: + Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021. + """ + super().__init__() + if not laplace_installed: + raise ImportError( + "The laplace-torch library is not installed. Please install it via `pip install laplace-torch`." + ) + self.la = Laplace( + model=model, + task=task, + subset_of_weights=subset_of_weights, + hessian_structure=hessian_structure, + ) + self.pred_type = pred_type + self.link_approx = link_approx + + def fit(self, dataset: Dataset) -> None: + self.la.fit(dataset=dataset) + + def forward( + self, + x: Tensor, + ) -> Tensor: + return self.la( + x, pred_type=self.pred_type, link_approx=self.link_approx + ) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index d99fdd7c..66d21889 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -99,7 +99,7 @@ def _est_forward(self, x: Tensor) -> Tensor: def forward( self, x: Tensor, - ) -> tuple[Tensor, Tensor]: + ) -> Tensor: if self.training: return self.model(x) if not self.trained: From acf90ebad3f504fcd11a6fe07f878d9d5104f7e1 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 3 Jun 2024 17:27:43 +0200 Subject: [PATCH 05/57] :books: Add Laplace to the references --- README.md | 1 + docs/source/references.rst | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/README.md b/README.md index 00d0bc2a..ebfe0957 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ To date, the following post-processing methods have been implemented: - Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html) - Monte Carlo Batch Normalization - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_batch_norm.html) +- A wrapper for Laplace appoximation using the [Laplace library](https://github.com/aleximmer/Laplace) ## Tutorials diff --git a/docs/source/references.rst b/docs/source/references.rst index bd4467c9..b219eb3f 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -193,6 +193,16 @@ For Monte-Carlo Batch Normalization, consider citing: * Authors: *Mathias Teye, Hossein Azizpour, and Kevin Smith* * Paper: `ICML 2018 `__. +Laplace Approximation +^^^^^^^^^^^^^^^^^^^^^ + +For Laplace Approximation, consider citing: + +**Laplace Redux - Effortless Bayesian Deep Learning** + +* Authors: *Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, Philipp Hennig* +* Paper: `NeurIPS 2021 `__. + Metrics ------- From 63fb8744c4da2ffd75252cccca604d3352271856 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 5 Jun 2024 11:58:55 +0200 Subject: [PATCH 06/57] :hammer: Refactor Mixup params --- tests/_dummies/baseline.py | 22 ++-- tests/routines/test_classification.py | 7 +- .../baselines/classification/resnet.py | 31 +---- .../baselines/classification/vgg.py | 31 +---- .../baselines/classification/wideresnet.py | 31 +---- torch_uncertainty/routines/classification.py | 120 +++++++++--------- 6 files changed, 97 insertions(+), 145 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index c43b444c..5a50cb5e 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -41,6 +41,7 @@ def __new__( kernel_tau_std: float = 0.5, mixup_alpha: float = 0, cutmix_alpha: float = 0, + no_mixup_params: bool = False, ) -> ClassificationRoutine: model = dummy_model( in_channels=in_channels, @@ -48,7 +49,18 @@ def __new__( with_feats=with_feats, with_linear=with_linear, ) - + if not no_mixup_params: + mixup_params = { + "mixup_alpha": mixup_alpha, + "cutmix_alpha": cutmix_alpha, + "mixtype": mixtype, + "mixmode": mixmode, + "dist_sim": dist_sim, + "kernel_tau_max": kernel_tau_max, + "kernel_tau_std": kernel_tau_std, + } + else: + mixup_params = None if baseline_type == "single": return ClassificationRoutine( num_classes=num_classes, @@ -58,13 +70,7 @@ def __new__( log_plots=True, optim_recipe=optim_recipe(model), num_estimators=1, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + mixup_params=mixup_params, ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 22f6cae6..4c40da79 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -81,6 +81,7 @@ def test_one_estimator_two_classes(self): baseline_type="single", ood_criterion="entropy", eval_ood=True, + no_mixup_params=True, ) trainer.fit(model, dm) @@ -366,8 +367,12 @@ def test_classification_failures(self): ) with pytest.raises(ValueError): + mixup_params = {"cutmix_alpha": -1} ClassificationRoutine( - num_classes=10, model=nn.Module(), loss=None, cutmix_alpha=-1 + num_classes=10, + model=nn.Module(), + loss=None, + mixup_params=mixup_params, ) with pytest.raises( diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index ff051b48..bd41cbe6 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -47,13 +47,7 @@ def __init__( style: str = "imagenet", num_estimators: int = 1, dropout_rate: float = 0.0, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, last_layer_dropout: bool = False, groups: int = 1, scale: float | None = None, @@ -108,17 +102,10 @@ def __init__( Only used if :attr:`version` is either ``"packed"``, ``"batched"``, ``"masked"`` or ``"mc-dropout"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel - tau. Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults - to ``0``. - cutmix_alpha (float, optional): Alpha parameter for CutMix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixtype, + mixmode, dist_sim, kernel_tau_max, kernel_tau_std, + mixup_alpha, and cutmix_alpha. If None, no augmentations. + Defaults to ``None``. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. scale (float, optional): Expansion factor affecting the width of @@ -228,13 +215,7 @@ def __init__( loss=loss, num_estimators=num_estimators, format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + mixup_params=mixup_params, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index fc4f5256..8988b613 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -32,13 +32,7 @@ def __init__( num_estimators: int = 1, dropout_rate: float = 0.0, last_layer_dropout: bool = False, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, groups: int = 1, alpha: int | None = None, gamma: int = 1, @@ -79,17 +73,10 @@ def __init__( Only used if :attr:`version` is either ``"packed"``, ``"batched"`` or ``"masked"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel - tau. Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults - to ``0``. - cutmix_alpha (float, optional): Alpha parameter for CutMix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixtype, + mixmode, dist_sim, kernel_tau_max, kernel_tau_std, + mixup_alpha, and cutmix_alpha. If None, no augmentations. + Defaults to ``None``. last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. @@ -177,13 +164,7 @@ def __init__( loss=loss, num_estimators=num_estimators, format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + mixup_params=mixup_params, eval_ood=eval_ood, ood_criterion=ood_criterion, log_plots=log_plots, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index ffda0d48..935049cd 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -39,13 +39,7 @@ def __init__( style: str = "imagenet", num_estimators: int = 1, dropout_rate: float = 0.0, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, groups: int = 1, last_layer_dropout: bool = False, scale: float | None = None, @@ -89,17 +83,10 @@ def __init__( Only used if :attr:`version` is either ``"packed"``, ``"batched"`` or ``"masked"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel - tau. Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults - to ``0``. - cutmix_alpha (float, optional): Alpha parameter for CutMix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixtype, + mixmode, dist_sim, kernel_tau_max, kernel_tau_std, + mixup_alpha, and cutmix_alpha. If None, no augmentations. + Defaults to ``None``. last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. @@ -199,13 +186,7 @@ def __init__( loss=loss, num_estimators=num_estimators, format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + mixup_params=mixup_params, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 0c07b8ab..90f3d2ba 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -37,6 +37,16 @@ from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup from torch_uncertainty.utils import csv_writer, plot_hist +MIXUP_PARAMS = { + "mixtype": "erm", + "mixmode": "elem", + "dist_sim": "emb", + "kernel_tau_max": 1.0, + "kernel_tau_std": 0.5, + "mixup_alpha": 0, + "cutmix_alpha": 0, +} + class ClassificationRoutine(LightningModule): def __init__( @@ -47,13 +57,7 @@ def __init__( num_estimators: int = 1, format_batch_fn: nn.Module | None = None, optim_recipe: dict | Optimizer | None = None, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, eval_ood: bool = False, eval_grouping_loss: bool = False, ood_criterion: Literal[ @@ -76,16 +80,10 @@ def __init__( Defaults to :class:`torch.nn.Identity()`. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel tau. - Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to ``0``. - cutmix_alpha (float, optional): Alpha parameter for Cutmix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixup type, + mixup mode, distance similarity, kernel tau max, kernel tau std, + mixup alpha, and cutmix alpha. If None, no augmentations. + Defaults to ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection performance or not. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the @@ -214,19 +212,9 @@ def __init__( self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") # Mixup - self.mixtype = mixtype - self.mixmode = mixmode - self.dist_sim = dist_sim - if num_estimators == 1: - if mixup_alpha < 0 or cutmix_alpha < 0: - raise ValueError( - "Cutmix alpha and Mixup alpha must be positive." - f"Got {mixup_alpha} and {cutmix_alpha}." - ) - self.mixup = self.init_mixup( - mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std - ) + if num_estimators == 1: + self.mixup = self.init_mixup(mixup_params) if self.eval_grouping_loss: grouping_loss = MetricCollection( @@ -243,46 +231,51 @@ def __init__( self.id_logit_storage = None self.ood_logit_storage = None - def init_mixup( - self, - mixup_alpha: float, - cutmix_alpha: float, - kernel_tau_max: float, - kernel_tau_std: float, - ) -> Callable: - if self.mixtype == "timm": + def init_mixup(self, mixup_params: dict | None) -> Callable: + if mixup_params is None: + mixup_params = {} + mixup_params = MIXUP_PARAMS | mixup_params + self.mixup_params = mixup_params + + if mixup_params["mixup_alpha"] < 0 or mixup_params["cutmix_alpha"] < 0: + raise ValueError( + "Cutmix alpha and Mixup alpha must be positive." + f"Got {mixup_params['mixup_alpha']} and {mixup_params['cutmix_alpha']}." + ) + + if mixup_params["mixtype"] == "timm": return timm_Mixup( - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, - mode=self.mixmode, + mixup_alpha=mixup_params["mixup_alpha"], + cutmix_alpha=mixup_params["cutmix_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "mixup": + if mixup_params["mixtype"] == "mixup": return Mixup( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "mixup_io": + if mixup_params["mixtype"] == "mixup_io": return MixupIO( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "regmixup": + if mixup_params["mixtype"] == "regmixup": return RegMixup( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "kernel_warping": + if mixup_params["mixtype"] == "kernel_warping": return WarpingMixup( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, apply_kernel=True, - tau_max=kernel_tau_max, - tau_std=kernel_tau_std, + tau_max=mixup_params["kernel_tau_max"], + tau_std=mixup_params["kernel_tau_std"], ) return Identity() @@ -338,22 +331,27 @@ def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: logits = self.model(inputs) return logits - def training_step( - self, batch: tuple[Tensor, Tensor], batch_idx: int - ) -> STEP_OUTPUT: - # Mixup only for single models + def apply_mixup( + self, batch: tuple[Tensor, Tensor] + ) -> tuple[Tensor, Tensor]: if self.num_estimators == 1: - if self.mixtype == "kernel_warping": - if self.dist_sim == "emb": + if self.mixup_params["mixtype"] == "kernel_warping": + if self.mixup_params["dist_sim"] == "emb": with torch.no_grad(): feats = self.model.feats_forward(batch[0]).detach() batch = self.mixup(*batch, feats) - elif self.dist_sim == "inp": + elif self.mixup_params["dist_sim"] == "inp": batch = self.mixup(*batch, batch[0]) else: batch = self.mixup(*batch) + return batch + def training_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> STEP_OUTPUT: + # Mixup only for single models + batch = self.apply_mixup(batch) inputs, target = self.format_batch_fn(batch) if self.is_elbo: From acbd5822daf0a093987fc98dfe794726a555ca99 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 6 Jun 2024 16:01:45 +0200 Subject: [PATCH 07/57] :bug: Fix #99 error in calibration plots --- tests/metrics/classification/test_calibration.py | 1 + torch_uncertainty/metrics/classification/calibration_error.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py index ee8ab224..3ad5e3f3 100644 --- a/tests/metrics/classification/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -15,6 +15,7 @@ def test_plot_binary(self) -> None: torch.as_tensor([0, 0, 1, 1, 1]), ) fig, ax = metric.plot() + metric.plot(ax=ax) assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) assert ax.get_xlabel() == "Top-class Confidence (%)" diff --git a/torch_uncertainty/metrics/classification/calibration_error.py b/torch_uncertainty/metrics/classification/calibration_error.py index 5def4d3c..512323b3 100644 --- a/torch_uncertainty/metrics/classification/calibration_error.py +++ b/torch_uncertainty/metrics/classification/calibration_error.py @@ -60,7 +60,8 @@ def _ce_plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: ax.set_xlim(0, 100) ax.set_ylim(0, 100) ax.set_aspect("equal", "box") - fig.tight_layout() + if fig is not None: + fig.tight_layout() return fig, ax From 8c2de92451aa8df4b86a2317cb0763a7b7bccd3d Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Jun 2024 10:29:04 +0200 Subject: [PATCH 08/57] :shirt: Slightly improve dropout --- auto_tutorials_source/tutorial_mc_dropout.py | 13 ++++---- torch_uncertainty/models/lenet.py | 6 ++-- torch_uncertainty/models/mc_dropout.py | 35 ++++++++++++++++---- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 59dd4241..cd410526 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -51,7 +51,8 @@ # dataloaders and transforms. We create the model using the # blueprint from torch_uncertainty.models and we wrap it into mc_dropout. # -# It is important to specify the arguments,``num_estimators`` and the ``dropout_rate`` +# It is important to specify the arguments,``num_estimators`` +# and the ``dropout_rate`` # to use Monte Carlo dropout. trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) @@ -64,7 +65,7 @@ model = lenet( in_channels=datamodule.num_channels, num_classes=datamodule.num_classes, - dropout_rate=0.4, + dropout_rate=0.5, ) mc_model = mc_dropout(model, num_estimators=16, last_layer=False) @@ -118,8 +119,8 @@ def imshow(img): images, labels = next(dataiter) # print images -imshow(torchvision.utils.make_grid(images[:4, ...])) -print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) +imshow(torchvision.utils.make_grid(images[:6, ...], padding=0)) +print("Ground truth labels: ", " ".join(f"{labels[j]}" for j in range(6))) routine.eval() logits = routine(images).reshape(16, 128, 10) @@ -127,7 +128,7 @@ def imshow(img): probs = torch.nn.functional.softmax(logits, dim=-1) -for j in range(4): +for j in range(6): values, predicted = torch.max(probs[:, j], 1) print( f"Predicted digits for the image {j+1}: ", @@ -135,5 +136,5 @@ def imshow(img): ) # %% -# We see that there is some disagreement between the samples of the dropout +# Most of the time, we see that there is some disagreement between the samples of the dropout # approximation of the posterior distribution. diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index b18fa488..34a4da00 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -39,7 +39,9 @@ def __init__( ): batchnorm = True else: - raise ValueError("norm must be nn.Identity or nn.BatchNorm2d") + raise ValueError( + f"norm must be nn.Identity or nn.BatchNorm2d. Got {norm}." + ) self.dropout_rate = dropout_rate self.last_layer_dropout = last_layer_dropout @@ -179,7 +181,7 @@ def bayesian_lenet( norm: type[nn.Module] = nn.Identity, groups: int = 1, dropout_rate: float = 0.0, -) -> _LeNet: +) -> _StochasticLeNet: layers_args = {} if prior_sigma_1 is not None: layers_args["prior_sigma_1"] = prior_sigma_1 diff --git a/torch_uncertainty/models/mc_dropout.py b/torch_uncertainty/models/mc_dropout.py index 24a545b3..ff0d389a 100644 --- a/torch_uncertainty/models/mc_dropout.py +++ b/torch_uncertainty/models/mc_dropout.py @@ -1,9 +1,14 @@ +import torch from torch import Tensor, nn class _MCDropout(nn.Module): def __init__( - self, model: nn.Module, num_estimators: int, last_layer: bool + self, + model: nn.Module, + num_estimators: int, + last_layer: bool, + on_batch: bool, ) -> None: """MC Dropout wrapper for a model containing nn.Dropout modules. @@ -11,6 +16,8 @@ def __init__( model (nn.Module): model to wrap num_estimators (int): number of estimators to use last_layer (bool): whether to apply dropout to the last layer only. + on_batch (bool): Increase the batch_size to perform MC-Dropout. + Otherwise in a for loop. Warning: Apply dropout using modules and not functional for this wrapper to @@ -27,6 +34,7 @@ def __init__( """ super().__init__() self.last_layer = last_layer + self.on_batch = on_batch if not hasattr(model, "dropout_rate"): raise ValueError( @@ -36,14 +44,14 @@ def __init__( raise ValueError( "`dropout_rate` must be strictly positive to use MC Dropout." ) + self.model = model + if num_estimators is None: raise ValueError("`num_estimators` must be set to use MC Dropout.") if num_estimators <= 0: raise ValueError( "`num_estimators` must be strictly positive to use MC Dropout." ) - - self.model = model self.num_estimators = num_estimators self.filtered_modules = list( @@ -77,12 +85,20 @@ def forward( x: Tensor, ) -> Tensor: if not self.training: - x = x.repeat(self.num_estimators, 1, 1, 1) + if self.on_batch: + x = x.repeat(self.num_estimators, 1, 1, 1) + return self.model(x) + return torch.cat( + [self.model(x) for _ in range(self.num_estimators)], dim=0 + ) return self.model(x) def mc_dropout( - model: nn.Module, num_estimators: int, last_layer: bool = False + model: nn.Module, + num_estimators: int, + last_layer: bool = False, + on_batch: bool = True, ) -> _MCDropout: """MC Dropout wrapper for a model. @@ -91,7 +107,14 @@ def mc_dropout( num_estimators (int): number of estimators to use last_layer (bool, optional): whether to apply dropout to the last layer only. Defaults to False. + on_batch (bool): Increase the batch_size to perform MC-Dropout. + Otherwise in a for loop to reduce memory footprint. Defaults + to true. + """ return _MCDropout( - model=model, num_estimators=num_estimators, last_layer=last_layer + model=model, + num_estimators=num_estimators, + last_layer=last_layer, + on_batch=on_batch, ) From 8655becd5e8c707cc31b0b66398c3d0ce02f3c86 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Jun 2024 10:36:00 +0200 Subject: [PATCH 09/57] :bug: Fix MC Dropout test --- tests/models/test_mc_dropout.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/models/test_mc_dropout.py b/tests/models/test_mc_dropout.py index b0cd9327..69e8e2e5 100644 --- a/tests/models/test_mc_dropout.py +++ b/tests/models/test_mc_dropout.py @@ -27,14 +27,23 @@ def test_mc_dropout_eval(self): assert not dropout_model.training dropout_model(torch.rand(1, 10)) + dropout_model = mc_dropout(model, num_estimators=5, on_batch=False) + dropout_model.eval() + assert not dropout_model.training + dropout_model(torch.rand(1, 10)) + def test_mc_dropout_errors(self): model = dummy_model(10, 5, 0.1) with pytest.raises(ValueError): - _MCDropout(model=model, num_estimators=-1, last_layer=True) + _MCDropout( + model=model, num_estimators=-1, last_layer=True, on_batch=True + ) with pytest.raises(ValueError): - _MCDropout(model=model, num_estimators=0, last_layer=False) + _MCDropout( + model=model, num_estimators=0, last_layer=False, on_batch=False + ) dropout_model = mc_dropout(model, 5) with pytest.raises(TypeError): From 91cf1c042c1df1486508238312f4bf54b040aa25 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Jun 2024 15:14:14 +0200 Subject: [PATCH 10/57] :book: Remove Packed-Ensembles mentionned twice --- docs/source/index.rst | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index b0af32c5..0c3c7994 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,7 +32,7 @@ To install TorchUncertainty with contribution in mind, check the ----- Official Implementations -^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^ TorchUncertainty also houses multiple official implementations of papers from major conferences & journals. @@ -56,14 +56,6 @@ TorchUncertainty also houses multiple official implementations of papers from ma * Authors: *Gianni Franchi, Xuanlong Yu, Andrei Bursuc, Angel Tena, Rémi Kazmierczak, Séverine Dubuisson, Emanuel Aldea, David Filliat* * Paper: `BMVC 2022 `_. -Packed-Ensembles -^^^^^^^^^^^^^^^^ - -**Packed-Ensembles for Efficient Uncertainty Estimation** - -* Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi* -* Paper: `here `_. - .. toctree:: :maxdepth: 2 :caption: Contents: From 988a89b9f5c0b1d987e3b394133150daca903656 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Jun 2024 17:03:32 +0200 Subject: [PATCH 11/57] :sparkles: Add Trajectory Ensemble --- .../mnist/configs/lenet_trajectory.yaml | 72 ++++++++++++ torch_uncertainty/models/__init__.py | 5 + torch_uncertainty/models/trajectory_models.py | 106 ++++++++++++++++++ torch_uncertainty/routines/classification.py | 17 ++- 4 files changed, 195 insertions(+), 5 deletions(-) create mode 100644 experiments/classification/mnist/configs/lenet_trajectory.yaml create mode 100644 torch_uncertainty/models/trajectory_models.py diff --git a/experiments/classification/mnist/configs/lenet_trajectory.yaml b/experiments/classification/mnist/configs/lenet_trajectory.yaml new file mode 100644 index 00000000..caa41348 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_trajectory.yaml @@ -0,0 +1,72 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.TrajectoryEnsemble + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + save_schedule: + - 20 + - 25 + - 30 + - 35 + - 40 + - 45 + - 50 + - 55 + - 60 + - 65 + - 70 + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 08dfc824..1025c613 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -1,3 +1,8 @@ # ruff: noqa: F401 from .deep_ensembles import deep_ensembles from .mc_dropout import mc_dropout +from .trajectory_models import ( + TrajectoryEnsemble, + TrajectoryModel, + trajectory_ensemble, +) diff --git a/torch_uncertainty/models/trajectory_models.py b/torch_uncertainty/models/trajectory_models.py new file mode 100644 index 00000000..8fcb32f9 --- /dev/null +++ b/torch_uncertainty/models/trajectory_models.py @@ -0,0 +1,106 @@ +import copy + +import torch +from torch import nn + + +class TrajectoryModel(nn.Module): + def __init__( + self, + model: nn.Module, + save_schedule: list[int] | None = None, + ) -> None: + """Ensemble of models at different points in the training trajectory. + + Args: + model (nn.Module): The model to train and ensemble. + save_schedule (list[int]): The epochs at which to save the model. + If save schedule is None, save the model at every epoch. + Defaults to None. + """ + super().__init__() + self.model = model + self.save_schedule = save_schedule + + self.saved_models = [] + + @torch.no_grad() + def save_model(self, epoch: int) -> None: + """Save the model at the given epoch if included in the schedule. + + Args: + epoch (int): The current epoch. + """ + if self.save_schedule is None or epoch in self.save_schedule: + self.saved_models.append(copy.deepcopy(self.model)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Subclasses must implement this method.") + + +class TrajectoryEnsemble(TrajectoryModel): + def __init__( + self, + model: nn.Module, + save_schedule: list[int] | None = None, + use_final_checkpoint: bool = True, + ) -> None: + """Ensemble of models at different points in the training trajectory. + + Args: + model (nn.Module): The model to train and ensemble. + save_schedule (list[int]): The epochs at which to save the model. + If save schedule is None, save the model at every epoch. + Defaults to None. + use_final_checkpoint (bool, optional): Whether to use the final + model as a checkpoint. Defaults to True. + """ + super().__init__(model, save_schedule) + self.use_final_checkpoint = use_final_checkpoint + self.num_estimators = int(use_final_checkpoint) + + @torch.no_grad() + def save_model(self, epoch: int) -> None: + """Save the model at the given epoch if included in the schedule. + + Args: + epoch (int): The current epoch. + """ + if self.save_schedule is None or epoch in self.save_schedule: + self.saved_models.append(copy.deepcopy(self.model)) + self.num_estimators += 1 + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for evaluation. + + If the model is in evaluation mode, this method will return the + ensemble prediction. Otherwise, it will return the prediction of the + current model. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The model or ensemble output. + """ + if not len(self.saved_models): + return self.model.forward(x) + preds = torch.cat( + [model.forward(x) for model in self.saved_models], dim=0 + ) + if self.use_final_checkpoint: + model_forward = self.model.forward(x) + preds = torch.cat([model_forward, preds], dim=0) + return preds + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) + + +def trajectory_ensemble( + model: nn.Module, + save_schedule: list[int], +) -> TrajectoryEnsemble: + return TrajectoryEnsemble(model, save_schedule) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 90f3d2ba..edf38e04 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -33,6 +33,7 @@ RiskAt80Cov, VariationRatio, ) +from torch_uncertainty.models import TrajectoryEnsemble, TrajectoryModel from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup from torch_uncertainty.utils import csv_writer, plot_hist @@ -138,6 +139,8 @@ def __init__( self.save_in_csv = save_in_csv self.calibration_set = calibration_set self.binary_cls = num_classes == 1 + self.is_trajectory_ensemble = isinstance(model, TrajectoryEnsemble) + self.is_trajectory_model = isinstance(model, TrajectoryModel) self.model = model self.loss = loss @@ -197,7 +200,7 @@ def __init__( self.test_ood_entropy = Entropy() # metrics for ensembles only - if self.num_estimators > 1: + if self.num_estimators > 1 or self.is_trajectory_ensemble: ens_metrics = MetricCollection( { "Disagreement": Disagreement(), @@ -211,10 +214,8 @@ def __init__( if self.eval_ood: self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") - # Mixup - if num_estimators == 1: - self.mixup = self.init_mixup(mixup_params) + self.mixup = self._init_mixup(mixup_params) if self.eval_grouping_loss: grouping_loss = MetricCollection( @@ -231,7 +232,7 @@ def __init__( self.id_logit_storage = None self.ood_logit_storage = None - def init_mixup(self, mixup_params: dict | None) -> Callable: + def _init_mixup(self, mixup_params: dict | None) -> Callable: if mixup_params is None: mixup_params = {} mixup_params = MIXUP_PARAMS | mixup_params @@ -488,6 +489,12 @@ def test_step( if self.ood_logit_storage is not None: self.ood_logit_storage.append(logits.detach().cpu()) + def on_train_epoch_end(self) -> None: + if self.is_trajectory_model: + self.model.save_model(self.current_epoch) + if self.is_trajectory_ensemble: + self.num_estimators = self.model.num_estimators + def on_validation_epoch_end(self) -> None: self.log_dict(self.val_cls_metrics.compute(), sync_dist=True) self.val_cls_metrics.reset() From 8e7c18809e06a5f406954e2770aaa5581d0464b3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Jun 2024 22:53:13 +0200 Subject: [PATCH 12/57] :sparkles: Add EMA & SWA & Reformat models --- .../mnist/configs/lenet_ema.yaml | 61 ++++++ .../mnist/configs/lenet_swa.yaml | 62 ++++++ tests/_dummies/baseline.py | 2 +- tests/models/test_mc_dropout.py | 6 +- torch_uncertainty/models/__init__.py | 10 +- torch_uncertainty/models/trajectory_models.py | 106 ---------- torch_uncertainty/models/wrappers/__init__.py | 9 + .../models/{ => wrappers}/deep_ensembles.py | 0 .../models/{ => wrappers}/mc_dropout.py | 6 +- .../models/wrappers/trajectory_models.py | 185 ++++++++++++++++++ torch_uncertainty/routines/classification.py | 25 ++- 11 files changed, 349 insertions(+), 123 deletions(-) create mode 100644 experiments/classification/mnist/configs/lenet_ema.yaml create mode 100644 experiments/classification/mnist/configs/lenet_swa.yaml delete mode 100644 torch_uncertainty/models/trajectory_models.py create mode 100644 torch_uncertainty/models/wrappers/__init__.py rename torch_uncertainty/models/{ => wrappers}/deep_ensembles.py (100%) rename torch_uncertainty/models/{ => wrappers}/mc_dropout.py (98%) create mode 100644 torch_uncertainty/models/wrappers/trajectory_models.py diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml new file mode 100644 index 00000000..5ec2dddc --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -0,0 +1,61 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_ema + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.EMA + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + momentum: 0.9 + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml new file mode 100644 index 00000000..82b6e66d --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -0,0 +1,62 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_swa + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.SWA + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + cycle_start: 9 + cycle_length: 5 + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 5a50cb5e..22cf6e5c 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -7,7 +7,7 @@ NormalInverseGammaLayer, NormalLayer, ) -from torch_uncertainty.models.deep_ensembles import deep_ensembles +from torch_uncertainty.models import deep_ensembles from torch_uncertainty.routines import ( ClassificationRoutine, PixelRegressionRoutine, diff --git a/tests/models/test_mc_dropout.py b/tests/models/test_mc_dropout.py index 69e8e2e5..23a70c6a 100644 --- a/tests/models/test_mc_dropout.py +++ b/tests/models/test_mc_dropout.py @@ -2,7 +2,7 @@ import torch from tests._dummies.model import dummy_model -from torch_uncertainty.models.mc_dropout import _MCDropout, mc_dropout +from torch_uncertainty.models import MCDropout, mc_dropout class TestMCDropout: @@ -36,12 +36,12 @@ def test_mc_dropout_errors(self): model = dummy_model(10, 5, 0.1) with pytest.raises(ValueError): - _MCDropout( + MCDropout( model=model, num_estimators=-1, last_layer=True, on_batch=True ) with pytest.raises(ValueError): - _MCDropout( + MCDropout( model=model, num_estimators=0, last_layer=False, on_batch=False ) diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 1025c613..33af2297 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -1,8 +1,10 @@ # ruff: noqa: F401 -from .deep_ensembles import deep_ensembles -from .mc_dropout import mc_dropout -from .trajectory_models import ( +from .wrappers import ( + EMA, + SWA, + MCDropout, TrajectoryEnsemble, TrajectoryModel, - trajectory_ensemble, + deep_ensembles, + mc_dropout, ) diff --git a/torch_uncertainty/models/trajectory_models.py b/torch_uncertainty/models/trajectory_models.py deleted file mode 100644 index 8fcb32f9..00000000 --- a/torch_uncertainty/models/trajectory_models.py +++ /dev/null @@ -1,106 +0,0 @@ -import copy - -import torch -from torch import nn - - -class TrajectoryModel(nn.Module): - def __init__( - self, - model: nn.Module, - save_schedule: list[int] | None = None, - ) -> None: - """Ensemble of models at different points in the training trajectory. - - Args: - model (nn.Module): The model to train and ensemble. - save_schedule (list[int]): The epochs at which to save the model. - If save schedule is None, save the model at every epoch. - Defaults to None. - """ - super().__init__() - self.model = model - self.save_schedule = save_schedule - - self.saved_models = [] - - @torch.no_grad() - def save_model(self, epoch: int) -> None: - """Save the model at the given epoch if included in the schedule. - - Args: - epoch (int): The current epoch. - """ - if self.save_schedule is None or epoch in self.save_schedule: - self.saved_models.append(copy.deepcopy(self.model)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - raise NotImplementedError("Subclasses must implement this method.") - - -class TrajectoryEnsemble(TrajectoryModel): - def __init__( - self, - model: nn.Module, - save_schedule: list[int] | None = None, - use_final_checkpoint: bool = True, - ) -> None: - """Ensemble of models at different points in the training trajectory. - - Args: - model (nn.Module): The model to train and ensemble. - save_schedule (list[int]): The epochs at which to save the model. - If save schedule is None, save the model at every epoch. - Defaults to None. - use_final_checkpoint (bool, optional): Whether to use the final - model as a checkpoint. Defaults to True. - """ - super().__init__(model, save_schedule) - self.use_final_checkpoint = use_final_checkpoint - self.num_estimators = int(use_final_checkpoint) - - @torch.no_grad() - def save_model(self, epoch: int) -> None: - """Save the model at the given epoch if included in the schedule. - - Args: - epoch (int): The current epoch. - """ - if self.save_schedule is None or epoch in self.save_schedule: - self.saved_models.append(copy.deepcopy(self.model)) - self.num_estimators += 1 - - def eval_forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass for evaluation. - - If the model is in evaluation mode, this method will return the - ensemble prediction. Otherwise, it will return the prediction of the - current model. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The model or ensemble output. - """ - if not len(self.saved_models): - return self.model.forward(x) - preds = torch.cat( - [model.forward(x) for model in self.saved_models], dim=0 - ) - if self.use_final_checkpoint: - model_forward = self.model.forward(x) - preds = torch.cat([model_forward, preds], dim=0) - return preds - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training: - return self.model.forward(x) - return self.eval_forward(x) - - -def trajectory_ensemble( - model: nn.Module, - save_schedule: list[int], -) -> TrajectoryEnsemble: - return TrajectoryEnsemble(model, save_schedule) diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py new file mode 100644 index 00000000..ce0e005a --- /dev/null +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -0,0 +1,9 @@ +# ruff: noqa: F401 +from .deep_ensembles import deep_ensembles +from .mc_dropout import MCDropout, mc_dropout +from .trajectory_models import ( + EMA, + SWA, + TrajectoryEnsemble, + TrajectoryModel, +) diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py similarity index 100% rename from torch_uncertainty/models/deep_ensembles.py rename to torch_uncertainty/models/wrappers/deep_ensembles.py diff --git a/torch_uncertainty/models/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py similarity index 98% rename from torch_uncertainty/models/mc_dropout.py rename to torch_uncertainty/models/wrappers/mc_dropout.py index ff0d389a..78c7c1d6 100644 --- a/torch_uncertainty/models/mc_dropout.py +++ b/torch_uncertainty/models/wrappers/mc_dropout.py @@ -2,7 +2,7 @@ from torch import Tensor, nn -class _MCDropout(nn.Module): +class MCDropout(nn.Module): def __init__( self, model: nn.Module, @@ -99,7 +99,7 @@ def mc_dropout( num_estimators: int, last_layer: bool = False, on_batch: bool = True, -) -> _MCDropout: +) -> MCDropout: """MC Dropout wrapper for a model. Args: @@ -112,7 +112,7 @@ def mc_dropout( to true. """ - return _MCDropout( + return MCDropout( model=model, num_estimators=num_estimators, last_layer=last_layer, diff --git a/torch_uncertainty/models/wrappers/trajectory_models.py b/torch_uncertainty/models/wrappers/trajectory_models.py new file mode 100644 index 00000000..a5609fe8 --- /dev/null +++ b/torch_uncertainty/models/wrappers/trajectory_models.py @@ -0,0 +1,185 @@ +import copy + +import torch +from torch import nn + + +class TrajectoryModel(nn.Module): + def __init__( + self, + model: nn.Module, + ) -> None: + """Updated model at different points in the training trajectory. + + Args: + model (nn.Module): The inner model. + """ + super().__init__() + self.model = model + self.num_estimators = 1 + + def update_model(self, epoch: int) -> None: + """Update the model. + + Args: + epoch (int): The current epoch. + """ + raise NotImplementedError("Subclasses must implement this method.") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Subclasses must implement this method.") + + +class TrajectoryEnsemble(TrajectoryModel): + def __init__( + self, + model: nn.Module, + save_schedule: list[int] | None = None, + use_final_checkpoint: bool = True, + ) -> None: + """Ensemble of models at different points in the training trajectory. + + Args: + model (nn.Module): The model to train and ensemble. + save_schedule (list[int]): The epochs at which to save the model. + If save schedule is None, save the model at every epoch. + Defaults to None. + use_final_checkpoint (bool, optional): Whether to use the final + model as a checkpoint. Defaults to True. + """ + super().__init__(model) + self.save_schedule = save_schedule + self.use_final_checkpoint = use_final_checkpoint + self.num_estimators = int(use_final_checkpoint) + self.saved_models = [] + + @torch.no_grad() + def update_model(self, epoch: int) -> None: + """Save the model at the given epoch if included in the schedule. + + Args: + epoch (int): The current epoch. + """ + if self.save_schedule is None or epoch in self.save_schedule: + self.saved_models.append(copy.deepcopy(self.model)) + self.num_estimators += 1 + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for evaluation. + + If the model is in evaluation mode, this method will return the + ensemble prediction. Otherwise, it will return the prediction of the + current model. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The model or ensemble output. + """ + if not len(self.saved_models): + return self.model.forward(x) + preds = torch.cat( + [model.forward(x) for model in self.saved_models], dim=0 + ) + if self.use_final_checkpoint: + model_forward = self.model.forward(x) + preds = torch.cat([model_forward, preds], dim=0) + return preds + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) + + +class EMA(TrajectoryModel): + def __init__( + self, + model: nn.Module, + momentum: float, + ) -> None: + """Exponential moving average model. + + Args: + model (nn.Module): The model to train and ensemble. + momentum (float): The momentum of the moving average. + """ + super().__init__(model) + self.ema_model = None + self.momentum = momentum + self.remainder = 1 - momentum + + def update_model(self, epoch: int) -> None: + """Update the EMA model. + + Args: + epoch (int): The current epoch. For API consistency. + """ + if self.ema_model is None: + self.ema_model = copy.deepcopy(self.model) + else: + for ema_param, param in zip( + self.ema_model.parameters(), + self.model.parameters(), + strict=False, + ): + ema_param.data = ( + ema_param.data * self.momentum + param.data * self.remainder + ) + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + if self.ema_model is None: + return self.model.forward(x) + return self.ema_model.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) + + +class SWA(TrajectoryModel): + def __init__( + self, + model: nn.Module, + cycle_start: int, + cycle_length: int, + ) -> None: + super().__init__(model) + self.cycle_start = cycle_start + self.cycle_length = cycle_length + self.num_averaged = 0 + self.swa_model = None + self.need_bn_update = False + + @torch.no_grad() + def update_model(self, epoch: int) -> None: + if ( + epoch >= self.cycle_start + and (epoch - self.cycle_start) % self.cycle_length == 0 + ): + if self.swa_model is None: + self.swa_model = copy.deepcopy(self.model) + self.num_averaged = 1 + else: + for swa_param, param in zip( + self.swa_model.parameters(), + self.model.parameters(), + strict=False, + ): + swa_param.data += (param.data - swa_param.data) / ( + self.num_averaged + 1 + ) + self.num_averaged += 1 + self.need_bn_update = True + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + if self.swa_model is None: + return self.model.forward(x) + return self.swa_model.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index edf38e04..d5f948b0 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -332,9 +332,10 @@ def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: logits = self.model(inputs) return logits - def apply_mixup( + def _apply_mixup( self, batch: tuple[Tensor, Tensor] ) -> tuple[Tensor, Tensor]: + # Mixup only for single models if self.num_estimators == 1: if self.mixup_params["mixtype"] == "kernel_warping": if self.mixup_params["dist_sim"] == "emb": @@ -351,8 +352,7 @@ def apply_mixup( def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: - # Mixup only for single models - batch = self.apply_mixup(batch) + batch = self._apply_mixup(batch) inputs, target = self.format_batch_fn(batch) if self.is_elbo: @@ -489,9 +489,22 @@ def test_step( if self.ood_logit_storage is not None: self.ood_logit_storage.append(logits.detach().cpu()) - def on_train_epoch_end(self) -> None: - if self.is_trajectory_model: - self.model.save_model(self.current_epoch) + def on_validation_epoch_start(self) -> None: + if ( + self.is_trajectory_model and self.current_epoch > 0 + ): # workaround of sanity checks + self.model.update_model(self.current_epoch) + if ( + hasattr(self.model, "need_bn_update") + and self.model.need_bn_update + ): + torch.optim.swa_utils.update_bn( + self.trainer.train_dataloader, + self.model, + device=self.device, + ) + self.model.need_bn_update = False + if self.is_trajectory_ensemble: self.num_estimators = self.model.num_estimators From 8b0a02aece959cc1696b70f69d5dd5aa844d4357 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Jun 2024 22:59:14 +0200 Subject: [PATCH 13/57] :book: Add SWA to docs --- README.md | 1 + docs/source/api.rst | 15 +++++---------- docs/source/references.rst | 9 +++++++++ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 00d0bc2a..2f96347d 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ To date, the following deep learning baselines have been implemented: - MIMO - Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) +- Stochastic Weight Averaging - Regression with Beta Gaussian NLL Loss - Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) diff --git a/docs/source/api.rst b/docs/source/api.rst index d4f99acf..538181a4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -153,8 +153,8 @@ Models .. currentmodule:: torch_uncertainty.models -Deep Ensembles -^^^^^^^^^^^^^^ +Wrappers +^^^^^^^^ .. autosummary:: :toctree: generated/ @@ -162,14 +162,9 @@ Deep Ensembles :template: class.rst deep_ensembles - -Monte Carlo Dropout - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: class.rst - + EMA + SWA + MC_Dropout mc_dropout Metrics diff --git a/docs/source/references.rst b/docs/source/references.rst index bd4467c9..34219386 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -73,6 +73,15 @@ For Monte-Carlo Dropout, consider citing: * Authors: *Yarin Gal and Zoubin Ghahramani* * Paper: `ICML 2016 `__. +Stochastic Weight Averaging +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For Stochastic Weight Averaging, consider citing: + +**Averaging Weights Leads to Wider Optima and Better Generalization** + +* Authors: *Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson* +* Paper: `UAI 2018 `__. BatchEnsemble ^^^^^^^^^^^^^ From 3c231e2ea615f5a4d2299e1a0341798abe03a811 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Jun 2024 23:45:37 +0200 Subject: [PATCH 14/57] :hammer: Refactor EMA, SWA, & Checkpoint Ens. --- docs/source/api.rst | 1 + docs/source/references.rst | 10 + ...ry.yaml => lenet_checkpoint_ensemble.yaml} | 2 +- .../mnist/configs/lenet_ema.yaml | 2 +- .../mnist/configs/lenet_swa.yaml | 2 +- torch_uncertainty/models/__init__.py | 5 +- torch_uncertainty/models/wrappers/__init__.py | 14 +- .../models/wrappers/checkpoint_ensemble.py | 73 +++++++ torch_uncertainty/models/wrappers/ema.py | 53 +++++ torch_uncertainty/models/wrappers/swa.py | 63 ++++++ .../models/wrappers/trajectory_models.py | 185 ------------------ torch_uncertainty/routines/classification.py | 57 +++--- 12 files changed, 244 insertions(+), 223 deletions(-) rename experiments/classification/mnist/configs/{lenet_trajectory.yaml => lenet_checkpoint_ensemble.yaml} (96%) create mode 100644 torch_uncertainty/models/wrappers/checkpoint_ensemble.py create mode 100644 torch_uncertainty/models/wrappers/ema.py create mode 100644 torch_uncertainty/models/wrappers/swa.py delete mode 100644 torch_uncertainty/models/wrappers/trajectory_models.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 538181a4..46a1828d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -162,6 +162,7 @@ Wrappers :template: class.rst deep_ensembles + CheckpointEnsemble EMA SWA MC_Dropout diff --git a/docs/source/references.rst b/docs/source/references.rst index 34219386..ec0694b0 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -83,6 +83,16 @@ For Stochastic Weight Averaging, consider citing: * Authors: *Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson* * Paper: `UAI 2018 `__. +CheckpointEnsemble +^^^^^^^^^^^^^^^^^^ + +For CheckpointEnsemble, consider citing: + +**Checkpoint Ensembles: Ensemble Methods from a Single Training Process** + +* Authors: *Hugh Chen, Scott Lundberg, Su-In Lee* +* Paper: `AAAI 2018 `__. + BatchEnsemble ^^^^^^^^^^^^^ diff --git a/experiments/classification/mnist/configs/lenet_trajectory.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml similarity index 96% rename from experiments/classification/mnist/configs/lenet_trajectory.yaml rename to experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index caa41348..e92e6ca7 100644 --- a/experiments/classification/mnist/configs/lenet_trajectory.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -28,7 +28,7 @@ trainer: check_finite: true model: model: - class_path: torch_uncertainty.models.TrajectoryEnsemble + class_path: torch_uncertainty.models.CheckpointEnsemble init_args: model: class_path: torch_uncertainty.models.lenet._LeNet diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 5ec2dddc..363461c6 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -43,7 +43,7 @@ model: dropout_rate: 0 last_layer_dropout: false layer_args: {} - momentum: 0.9 + momentum: 0.99 num_classes: 10 loss: CrossEntropyLoss data: diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index 82b6e66d..b79647c1 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -43,7 +43,7 @@ model: dropout_rate: 0 last_layer_dropout: false layer_args: {} - cycle_start: 9 + cycle_start: 19 cycle_length: 5 num_classes: 10 loss: CrossEntropyLoss diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 33af2297..82b14a7f 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -1,10 +1,11 @@ # ruff: noqa: F401 from .wrappers import ( EMA, + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, SWA, + CheckpointEnsemble, MCDropout, - TrajectoryEnsemble, - TrajectoryModel, deep_ensembles, mc_dropout, ) diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index ce0e005a..4299d3b0 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -1,9 +1,11 @@ # ruff: noqa: F401 +from .checkpoint_ensemble import ( + CheckpointEnsemble, +) from .deep_ensembles import deep_ensembles +from .ema import EMA from .mc_dropout import MCDropout, mc_dropout -from .trajectory_models import ( - EMA, - SWA, - TrajectoryEnsemble, - TrajectoryModel, -) +from .swa import SWA + +STEP_UPDATE_MODEL = (EMA,) +EPOCH_UPDATE_MODEL = (SWA, CheckpointEnsemble) diff --git a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py new file mode 100644 index 00000000..125e8bd0 --- /dev/null +++ b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py @@ -0,0 +1,73 @@ +import copy + +import torch +from torch import nn + + +class CheckpointEnsemble(nn.Module): + def __init__( + self, + model: nn.Module, + save_schedule: list[int] | None = None, + use_final_checkpoint: bool = True, + ) -> None: + """Ensemble of models at different points in the training trajectory. + + Args: + model (nn.Module): The model to train and ensemble. + save_schedule (list[int]): The epochs at which to save the model. + If save schedule is None, save the model at every epoch. + Defaults to None. + use_final_checkpoint (bool, optional): Whether to use the final + model as a checkpoint. Defaults to True. + + Reference: + Checkpoint Ensembles: Ensemble Methods from a Single Training Process. + Hugh Chen, Scott Lundberg, Su-In Lee. In AAAI 2018. + """ + super().__init__() + self.model = model + self.save_schedule = save_schedule + self.use_final_checkpoint = use_final_checkpoint + self.num_estimators = int(use_final_checkpoint) + self.saved_models = [] + self.num_estimators = 1 + + @torch.no_grad() + def update_model(self, epoch: int) -> None: + """Save the model at the given epoch if included in the schedule. + + Args: + epoch (int): The current epoch. + """ + if self.save_schedule is None or epoch in self.save_schedule: + self.saved_models.append(copy.deepcopy(self.model)) + self.num_estimators += 1 + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for evaluation. + + If the model is in evaluation mode, this method will return the + ensemble prediction. Otherwise, it will return the prediction of the + current model. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The model or ensemble output. + """ + if not len(self.saved_models): + return self.model.forward(x) + preds = torch.cat( + [model.forward(x) for model in self.saved_models], dim=0 + ) + if self.use_final_checkpoint: + model_forward = self.model.forward(x) + preds = torch.cat([model_forward, preds], dim=0) + return preds + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) diff --git a/torch_uncertainty/models/wrappers/ema.py b/torch_uncertainty/models/wrappers/ema.py new file mode 100644 index 00000000..b6d7476e --- /dev/null +++ b/torch_uncertainty/models/wrappers/ema.py @@ -0,0 +1,53 @@ +import copy + +from torch import Tensor, nn + + +class EMA(nn.Module): + def __init__( + self, + model: nn.Module, + momentum: float, + ) -> None: + """Exponential moving average model. + + Args: + model (nn.Module): The model to train and ensemble. + momentum (float): The momentum of the moving average. + """ + super().__init__() + _ema_checks(momentum) + self.model = model + self.ema_model = copy.deepcopy(model) + self.momentum = momentum + self.remainder = 1 - momentum + + def update_model(self, epoch: int) -> None: + """Update the EMA model. + + Args: + epoch (int): The current epoch. For API consistency. + """ + for ema_param, param in zip( + self.ema_model.parameters(), + self.model.parameters(), + strict=False, + ): + ema_param.data = ( + ema_param.data * self.momentum + param.data * self.remainder + ) + + def eval_forward(self, x: Tensor) -> Tensor: + return self.ema_model.forward(x) + + def forward(self, x: Tensor) -> Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) + + +def _ema_checks(momentum: float) -> None: + if momentum < 0.0 or momentum >= 1.0: + raise ValueError( + f"`momentum` must be in the range [0, 1). Got {momentum}." + ) diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py new file mode 100644 index 00000000..ad9acc15 --- /dev/null +++ b/torch_uncertainty/models/wrappers/swa.py @@ -0,0 +1,63 @@ +import copy + +import torch +from torch import nn + + +class SWA(nn.Module): + def __init__( + self, + model: nn.Module, + cycle_start: int, + cycle_length: int, + ) -> None: + super().__init__() + _swa_checks(cycle_start, cycle_length) + self.model = model + self.cycle_start = cycle_start + self.cycle_length = cycle_length + self.num_averaged = 0 + self.swa_model = None + self.need_bn_update = False + + @torch.no_grad() + def update_model(self, epoch: int) -> None: + if ( + epoch >= self.cycle_start + and (epoch - self.cycle_start) % self.cycle_length == 0 + ): + if self.swa_model is None: + self.swa_model = copy.deepcopy(self.model) + self.num_averaged = 1 + else: + for swa_param, param in zip( + self.swa_model.parameters(), + self.model.parameters(), + strict=False, + ): + swa_param.data += (param.data - swa_param.data) / ( + self.num_averaged + 1 + ) + self.num_averaged += 1 + self.need_bn_update = True + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + if self.swa_model is None: + return self.model.forward(x) + return self.swa_model.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) + + +def _swa_checks(cycle_start: int, cycle_length: int) -> None: + if cycle_start < 0: + raise ValueError( + f"`cycle_start` must be non-negative. Got {cycle_start}." + ) + if cycle_length <= 0: + raise ValueError( + f"`cycle_length` must be strictly positive. Got {cycle_length}." + ) diff --git a/torch_uncertainty/models/wrappers/trajectory_models.py b/torch_uncertainty/models/wrappers/trajectory_models.py deleted file mode 100644 index a5609fe8..00000000 --- a/torch_uncertainty/models/wrappers/trajectory_models.py +++ /dev/null @@ -1,185 +0,0 @@ -import copy - -import torch -from torch import nn - - -class TrajectoryModel(nn.Module): - def __init__( - self, - model: nn.Module, - ) -> None: - """Updated model at different points in the training trajectory. - - Args: - model (nn.Module): The inner model. - """ - super().__init__() - self.model = model - self.num_estimators = 1 - - def update_model(self, epoch: int) -> None: - """Update the model. - - Args: - epoch (int): The current epoch. - """ - raise NotImplementedError("Subclasses must implement this method.") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - raise NotImplementedError("Subclasses must implement this method.") - - -class TrajectoryEnsemble(TrajectoryModel): - def __init__( - self, - model: nn.Module, - save_schedule: list[int] | None = None, - use_final_checkpoint: bool = True, - ) -> None: - """Ensemble of models at different points in the training trajectory. - - Args: - model (nn.Module): The model to train and ensemble. - save_schedule (list[int]): The epochs at which to save the model. - If save schedule is None, save the model at every epoch. - Defaults to None. - use_final_checkpoint (bool, optional): Whether to use the final - model as a checkpoint. Defaults to True. - """ - super().__init__(model) - self.save_schedule = save_schedule - self.use_final_checkpoint = use_final_checkpoint - self.num_estimators = int(use_final_checkpoint) - self.saved_models = [] - - @torch.no_grad() - def update_model(self, epoch: int) -> None: - """Save the model at the given epoch if included in the schedule. - - Args: - epoch (int): The current epoch. - """ - if self.save_schedule is None or epoch in self.save_schedule: - self.saved_models.append(copy.deepcopy(self.model)) - self.num_estimators += 1 - - def eval_forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass for evaluation. - - If the model is in evaluation mode, this method will return the - ensemble prediction. Otherwise, it will return the prediction of the - current model. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The model or ensemble output. - """ - if not len(self.saved_models): - return self.model.forward(x) - preds = torch.cat( - [model.forward(x) for model in self.saved_models], dim=0 - ) - if self.use_final_checkpoint: - model_forward = self.model.forward(x) - preds = torch.cat([model_forward, preds], dim=0) - return preds - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training: - return self.model.forward(x) - return self.eval_forward(x) - - -class EMA(TrajectoryModel): - def __init__( - self, - model: nn.Module, - momentum: float, - ) -> None: - """Exponential moving average model. - - Args: - model (nn.Module): The model to train and ensemble. - momentum (float): The momentum of the moving average. - """ - super().__init__(model) - self.ema_model = None - self.momentum = momentum - self.remainder = 1 - momentum - - def update_model(self, epoch: int) -> None: - """Update the EMA model. - - Args: - epoch (int): The current epoch. For API consistency. - """ - if self.ema_model is None: - self.ema_model = copy.deepcopy(self.model) - else: - for ema_param, param in zip( - self.ema_model.parameters(), - self.model.parameters(), - strict=False, - ): - ema_param.data = ( - ema_param.data * self.momentum + param.data * self.remainder - ) - - def eval_forward(self, x: torch.Tensor) -> torch.Tensor: - if self.ema_model is None: - return self.model.forward(x) - return self.ema_model.forward(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training: - return self.model.forward(x) - return self.eval_forward(x) - - -class SWA(TrajectoryModel): - def __init__( - self, - model: nn.Module, - cycle_start: int, - cycle_length: int, - ) -> None: - super().__init__(model) - self.cycle_start = cycle_start - self.cycle_length = cycle_length - self.num_averaged = 0 - self.swa_model = None - self.need_bn_update = False - - @torch.no_grad() - def update_model(self, epoch: int) -> None: - if ( - epoch >= self.cycle_start - and (epoch - self.cycle_start) % self.cycle_length == 0 - ): - if self.swa_model is None: - self.swa_model = copy.deepcopy(self.model) - self.num_averaged = 1 - else: - for swa_param, param in zip( - self.swa_model.parameters(), - self.model.parameters(), - strict=False, - ): - swa_param.data += (param.data - swa_param.data) / ( - self.num_averaged + 1 - ) - self.num_averaged += 1 - self.need_bn_update = True - - def eval_forward(self, x: torch.Tensor) -> torch.Tensor: - if self.swa_model is None: - return self.model.forward(x) - return self.swa_model.forward(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training: - return self.model.forward(x) - return self.eval_forward(x) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index d5f948b0..40d50e3d 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -33,7 +33,11 @@ RiskAt80Cov, VariationRatio, ) -from torch_uncertainty.models import TrajectoryEnsemble, TrajectoryModel +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, + CheckpointEnsemble, +) from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup from torch_uncertainty.utils import csv_writer, plot_hist @@ -139,8 +143,8 @@ def __init__( self.save_in_csv = save_in_csv self.calibration_set = calibration_set self.binary_cls = num_classes == 1 - self.is_trajectory_ensemble = isinstance(model, TrajectoryEnsemble) - self.is_trajectory_model = isinstance(model, TrajectoryModel) + self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) self.model = model self.loss = loss @@ -200,7 +204,7 @@ def __init__( self.test_ood_entropy = Entropy() # metrics for ensembles only - if self.num_estimators > 1 or self.is_trajectory_ensemble: + if self.num_estimators > 1 or isinstance(model, CheckpointEnsemble): ens_metrics = MetricCollection( { "Disagreement": Disagreement(), @@ -289,6 +293,25 @@ def on_train_start(self) -> None: self.hparams, ) + def on_validation_start(self) -> None: + if ( + self.need_epoch_update and self.current_epoch > 0 + ): # workaround of sanity checks + self.model.update_model(self.current_epoch) + if ( + hasattr(self.model, "need_bn_update") + and self.model.need_bn_update + ): + torch.optim.swa_utils.update_bn( + self.trainer.train_dataloader, + self.model, + device=self.device, + ) + self.model.need_bn_update = False + + if isinstance(self.model, CheckpointEnsemble): + self.num_estimators = self.model.num_estimators + def on_test_start(self) -> None: if isinstance(self.calibration_set, str) and self.calibration_set in [ "val", @@ -368,7 +391,8 @@ def training_step( loss = self.loss(logits, target) else: loss = self.loss(logits, target, self.current_epoch) - + if self.need_step_update: + self.model.update_model(self.current_epoch) self.log("train_loss", loss) return loss @@ -376,9 +400,7 @@ def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: inputs, target = batch - logits = self.forward( - inputs, save_feats=self.eval_grouping_loss - ) # (m*b, c) + logits = self.forward(inputs, save_feats=self.eval_grouping_loss) logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) if self.binary_cls: @@ -489,25 +511,6 @@ def test_step( if self.ood_logit_storage is not None: self.ood_logit_storage.append(logits.detach().cpu()) - def on_validation_epoch_start(self) -> None: - if ( - self.is_trajectory_model and self.current_epoch > 0 - ): # workaround of sanity checks - self.model.update_model(self.current_epoch) - if ( - hasattr(self.model, "need_bn_update") - and self.model.need_bn_update - ): - torch.optim.swa_utils.update_bn( - self.trainer.train_dataloader, - self.model, - device=self.device, - ) - self.model.need_bn_update = False - - if self.is_trajectory_ensemble: - self.num_estimators = self.model.num_estimators - def on_validation_epoch_end(self) -> None: self.log_dict(self.val_cls_metrics.compute(), sync_dist=True) self.val_cls_metrics.reset() From 1f72eadcbf61e14f9861df37d2edf96d1e297732 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Jun 2024 23:57:50 +0200 Subject: [PATCH 15/57] :book: Fix conf error --- docs/source/references.rst | 2 +- torch_uncertainty/models/wrappers/checkpoint_ensemble.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/references.rst b/docs/source/references.rst index ec0694b0..0273fd41 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -91,7 +91,7 @@ For CheckpointEnsemble, consider citing: **Checkpoint Ensembles: Ensemble Methods from a Single Training Process** * Authors: *Hugh Chen, Scott Lundberg, Su-In Lee* -* Paper: `AAAI 2018 `__. +* Paper: `ArXiv `__. BatchEnsemble ^^^^^^^^^^^^^ diff --git a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py index 125e8bd0..5b446882 100644 --- a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py +++ b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py @@ -23,7 +23,7 @@ def __init__( Reference: Checkpoint Ensembles: Ensemble Methods from a Single Training Process. - Hugh Chen, Scott Lundberg, Su-In Lee. In AAAI 2018. + Hugh Chen, Scott Lundberg, Su-In Lee. In ArXiv 2018. """ super().__init__() self.model = model From c091c9a12faf5fdd1fdb83ee40fcec776ecd6bce Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 13 Jun 2024 09:47:36 +0200 Subject: [PATCH 16/57] :shirt: Small changes --- README.md | 1 + torch_uncertainty/models/segmentation/segformer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2f96347d..af96b2e8 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ To date, the following deep learning baselines have been implemented: - MIMO - Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) +- Checkpoint Ensembles - Stochastic Weight Averaging - Regression with Beta Gaussian NLL Loss - Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) diff --git a/torch_uncertainty/models/segmentation/segformer.py b/torch_uncertainty/models/segmentation/segformer.py index 763aea71..6c34dfcb 100644 --- a/torch_uncertainty/models/segmentation/segformer.py +++ b/torch_uncertainty/models/segmentation/segformer.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.layers import DropPath, to_2tuple, trunc_normal_ from torch import Tensor, nn From 4bfd35132c8bdb97dc9c7b26d65df1eb9eba89a0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 13 Jun 2024 09:50:27 +0200 Subject: [PATCH 17/57] :hammer: Refactor the post processing methods --- torch_uncertainty/post_processing/__init__.py | 1 + torch_uncertainty/post_processing/abstract.py | 21 +++++++++++++++++++ .../post_processing/calibration/scaler.py | 4 +++- .../post_processing/mc_batch_norm.py | 5 +++-- 4 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 torch_uncertainty/post_processing/abstract.py diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index edbdceef..793e3637 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,3 +1,4 @@ # ruff: noqa: F401 +from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py new file mode 100644 index 00000000..9c002a0c --- /dev/null +++ b/torch_uncertainty/post_processing/abstract.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + +from torch import Tensor, nn +from torch.utils.data import Dataset + + +class PostProcessing(ABC, nn.Module): + def __init__(self): + super().__init__() + self.trained = False + + @abstractmethod + def fit(self, dataset: Dataset) -> None: + pass + + @abstractmethod + def forward( + self, + x: Tensor, + ) -> Tensor: + pass diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index d87730b9..5076c70f 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -5,8 +5,10 @@ from torch.utils.data import DataLoader, Dataset from tqdm import tqdm +from torch_uncertainty.post_processing import PostProcessing -class Scaler(nn.Module): + +class Scaler(PostProcessing): criterion = nn.CrossEntropyLoss() trained = False diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index d99fdd7c..2fa5f5b1 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -6,9 +6,10 @@ from torch.utils.data import DataLoader, Dataset from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d +from torch_uncertainty.post_processing import PostProcessing -class MCBatchNorm(nn.Module): +class MCBatchNorm(PostProcessing): counter: int = 0 mc_batch_norm_layers: list[MCBatchNorm2d] = [] trained = False @@ -99,7 +100,7 @@ def _est_forward(self, x: Tensor) -> Tensor: def forward( self, x: Tensor, - ) -> tuple[Tensor, Tensor]: + ) -> Tensor: if self.training: return self.model(x) if not self.trained: From be85bf85b942cba52e0e20f988064d5471eafbff Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 13 Jun 2024 09:56:17 +0200 Subject: [PATCH 18/57] :hammer: Refactor the AbstractDatamodule --- tests/_dummies/datamodule.py | 10 +++++----- tests/datamodules/test_abstract_datamodule.py | 12 ++++++------ torch_uncertainty/datamodules/abstract.py | 16 ++++++++++------ .../datamodules/classification/cifar10.py | 4 ++-- .../datamodules/classification/cifar100.py | 4 ++-- .../datamodules/classification/imagenet.py | 4 ++-- .../datamodules/classification/mnist.py | 4 ++-- .../datamodules/classification/tiny_imagenet.py | 4 ++-- torch_uncertainty/datamodules/depth/base.py | 4 ++-- .../datamodules/segmentation/camvid.py | 4 ++-- .../datamodules/segmentation/cityscapes.py | 4 ++-- .../datamodules/segmentation/muad.py | 4 ++-- torch_uncertainty/datamodules/uci_regression.py | 4 ++-- torch_uncertainty/post_processing/__init__.py | 2 +- torch_uncertainty/post_processing/abstract.py | 2 +- .../post_processing/calibration/scaler.py | 4 ++-- .../post_processing/mc_batch_norm.py | 4 ++-- 17 files changed, 47 insertions(+), 43 deletions(-) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 51c769dd..3de1850a 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from torchvision import tv_tensors -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from .dataset import ( DummyClassificationDataset, @@ -17,7 +17,7 @@ ) -class DummyClassificationDataModule(AbstractDataModule): +class DummyClassificationDataModule(BaseDataModule): num_channels = 1 image_size: int = 4 training_task = "classification" @@ -104,7 +104,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummyRegressionDataModule(AbstractDataModule): +class DummyRegressionDataModule(BaseDataModule): in_features = 4 training_task = "regression" @@ -160,7 +160,7 @@ def test_dataloader(self) -> DataLoader | list[DataLoader]: return [self._data_loader(self.test)] -class DummySegmentationDataModule(AbstractDataModule): +class DummySegmentationDataModule(BaseDataModule): num_channels = 3 training_task = "segmentation" @@ -249,7 +249,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummyDepthDataModule(AbstractDataModule): +class DummyDepthDataModule(BaseDataModule): num_channels = 3 training_task = "pixel_regression" diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 7b0f5e66..e54c7d74 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -4,16 +4,16 @@ from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.datamodules.abstract import ( - AbstractDataModule, + BaseDataModule, CrossValDataModule, ) -class TestAbstractDataModule: - """Testing the AbstractDataModule class.""" +class TestBaseDataModule: + """Testing the BaseDataModule class.""" def test_errors(self): - dm = AbstractDataModule("root", 128, 0.0, 4, True, True) + dm = BaseDataModule("root", 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): dm.setup() dm._get_train_data() @@ -24,7 +24,7 @@ class TestCrossValDataModule: """Testing the CrossValDataModule class.""" def test_cv_main(self): - dm = AbstractDataModule("root", 128, 0.0, 4, True, True) + dm = BaseDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds @@ -46,7 +46,7 @@ def test_cv_main(self): cv_dm.test_dataloader() def test_errors(self): - dm = AbstractDataModule("root", 128, 0.0, 4, True, True) + dm = BaseDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 1da19ced..c3b25cda 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from pathlib import Path from typing import Literal @@ -8,7 +9,7 @@ from torch.utils.data.sampler import SubsetRandomSampler -class AbstractDataModule(LightningDataModule): +class BaseDataModule(ABC, LightningDataModule): training_task: str train: Dataset val: Dataset @@ -47,8 +48,9 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers + @abstractmethod def setup(self, stage: Literal["fit", "test"] | None = None) -> None: - raise NotImplementedError + pass def get_train_set(self) -> Dataset: """Get the training set.""" @@ -113,11 +115,13 @@ def _data_loader( # by setting the correct path to the matrix of data for each dataset. # It is generally "Dataset.samples" or "Dataset.data" # They are used for constructing cross validation splits + @abstractmethod def _get_train_data(self) -> ArrayLike: - raise NotImplementedError + pass + @abstractmethod def _get_train_targets(self) -> ArrayLike: - raise NotImplementedError + pass def make_cross_val_splits( self, n_splits: int = 10, train_over: int = 4 @@ -148,13 +152,13 @@ def make_cross_val_splits( return cv_dm -class CrossValDataModule(AbstractDataModule): +class CrossValDataModule(BaseDataModule): def __init__( self, root: str | Path, train_idx: ArrayLike, val_idx: ArrayLike, - datamodule: AbstractDataModule, + datamodule: BaseDataModule, batch_size: int, val_split: float, num_workers: int, diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 45452115..65d58160 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -9,14 +9,14 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10, SVHN -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class CIFAR10DataModule(AbstractDataModule): +class CIFAR10DataModule(BaseDataModule): num_classes = 10 num_channels = 3 input_shape = (3, 32, 32) diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index bc5a3691..e6267cde 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -10,14 +10,14 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class CIFAR100DataModule(AbstractDataModule): +class CIFAR100DataModule(BaseDataModule): num_classes = 100 num_channels = 3 input_shape = (3, 32, 32) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d215a79f..d1321a3f 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets.classification import ( ImageNetA, ImageNetO, @@ -23,7 +23,7 @@ ) -class ImageNetDataModule(AbstractDataModule): +class ImageNetDataModule(BaseDataModule): num_classes = 1000 num_channels = 3 test_datasets = ["r", "o", "a"] diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index b18b279d..a71835c2 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -6,13 +6,13 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class MNISTDataModule(AbstractDataModule): +class MNISTDataModule(BaseDataModule): num_classes = 10 num_channels = 1 input_shape = (1, 28, 28) diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 25c62f31..10aab380 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -9,7 +9,7 @@ from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet from torch_uncertainty.utils import ( create_train_val_split, @@ -17,7 +17,7 @@ ) -class TinyImageNetDataModule(AbstractDataModule): +class TinyImageNetDataModule(BaseDataModule): num_classes = 200 num_channels = 3 training_task = "classification" diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 47e8cf73..2337e8e2 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -7,12 +7,12 @@ from torchvision.datasets import VisionDataset from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class DepthDataModule(AbstractDataModule): +class DepthDataModule(BaseDataModule): def __init__( self, dataset: type[VisionDataset], diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index 4a4aee65..afa5a6f2 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -4,11 +4,11 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets.segmentation import CamVid -class CamVidDataModule(AbstractDataModule): +class CamVidDataModule(BaseDataModule): def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index f35bd65d..c6800112 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -6,13 +6,13 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets.segmentation import Cityscapes from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class CityscapesDataModule(AbstractDataModule): +class CityscapesDataModule(BaseDataModule): def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index c126b05e..37949b1a 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -6,13 +6,13 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import BaseDataModule from torch_uncertainty.datasets import MUAD from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class MUADDataModule(AbstractDataModule): +class MUADDataModule(BaseDataModule): def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 66571959..2caea314 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -6,10 +6,10 @@ from torch_uncertainty.datasets.regression import UCIRegression -from .abstract import AbstractDataModule +from .abstract import BaseDataModule -class UCIDataModule(AbstractDataModule): +class UCIDataModule(BaseDataModule): training_task = "regression" def __init__( diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index 793e3637..66230e1f 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,4 +1,4 @@ # ruff: noqa: F401 -from .abstract import PostProcessing +from .abstract import BasePostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py index 9c002a0c..bacde124 100644 --- a/torch_uncertainty/post_processing/abstract.py +++ b/torch_uncertainty/post_processing/abstract.py @@ -4,7 +4,7 @@ from torch.utils.data import Dataset -class PostProcessing(ABC, nn.Module): +class BasePostProcessing(ABC, nn.Module): def __init__(self): super().__init__() self.trained = False diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 5076c70f..999bd0da 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -5,10 +5,10 @@ from torch.utils.data import DataLoader, Dataset from tqdm import tqdm -from torch_uncertainty.post_processing import PostProcessing +from torch_uncertainty.post_processing import BasePostProcessing -class Scaler(PostProcessing): +class Scaler(BasePostProcessing): criterion = nn.CrossEntropyLoss() trained = False diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index 2fa5f5b1..6acb7884 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -6,10 +6,10 @@ from torch.utils.data import DataLoader, Dataset from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d -from torch_uncertainty.post_processing import PostProcessing +from torch_uncertainty.post_processing import BasePostProcessing -class MCBatchNorm(PostProcessing): +class MCBatchNorm(BasePostProcessing): counter: int = 0 mc_batch_norm_layers: list[MCBatchNorm2d] = [] trained = False From 951ff09319f1a3f0a6c2c403b6abaf7cdd4049cf Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 13 Jun 2024 10:02:55 +0200 Subject: [PATCH 19/57] :bug: Fix test of abstract methods --- tests/datamodules/test_abstract_datamodule.py | 3 +++ torch_uncertainty/datamodules/abstract.py | 6 ++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index e54c7d74..25bb1952 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -13,6 +13,7 @@ class TestBaseDataModule: """Testing the BaseDataModule class.""" def test_errors(self): + BaseDataModule.__abstractmethods__ = set() dm = BaseDataModule("root", 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): dm.setup() @@ -24,6 +25,7 @@ class TestCrossValDataModule: """Testing the CrossValDataModule class.""" def test_cv_main(self): + BaseDataModule.__abstractmethods__ = set() dm = BaseDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds @@ -46,6 +48,7 @@ def test_cv_main(self): cv_dm.test_dataloader() def test_errors(self): + BaseDataModule.__abstractmethods__ = set() dm = BaseDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index c3b25cda..d5a9020f 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -115,13 +115,11 @@ def _data_loader( # by setting the correct path to the matrix of data for each dataset. # It is generally "Dataset.samples" or "Dataset.data" # They are used for constructing cross validation splits - @abstractmethod def _get_train_data(self) -> ArrayLike: - pass + raise NotImplementedError - @abstractmethod def _get_train_targets(self) -> ArrayLike: - pass + raise NotImplementedError def make_cross_val_splits( self, n_splits: int = 10, train_over: int = 4 From 83c4cceb8c4c98d6897aab7f38be18dbf71f6715 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Jun 2024 11:36:39 +0200 Subject: [PATCH 20/57] :hammer: Refactor pp methods --- torch_uncertainty/post_processing/__init__.py | 2 +- torch_uncertainty/post_processing/abstract.py | 8 ++- .../calibration/matrix_scaler.py | 2 +- .../post_processing/calibration/scaler.py | 15 ++-- .../calibration/temperature_scaler.py | 2 +- .../calibration/vector_scaler.py | 2 +- torch_uncertainty/post_processing/laplace.py | 34 +++++++--- .../post_processing/mc_batch_norm.py | 68 ++++++++++++------- 8 files changed, 80 insertions(+), 53 deletions(-) diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index 66230e1f..793e3637 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,4 +1,4 @@ # ruff: noqa: F401 -from .abstract import BasePostProcessing +from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py index bacde124..9c7908cc 100644 --- a/torch_uncertainty/post_processing/abstract.py +++ b/torch_uncertainty/post_processing/abstract.py @@ -4,11 +4,15 @@ from torch.utils.data import Dataset -class BasePostProcessing(ABC, nn.Module): - def __init__(self): +class PostProcessing(ABC, nn.Module): + def __init__(self, model: nn.Module | None = None): super().__init__() + self.model = model self.trained = False + def set_model(self, model: nn.Module) -> None: + self.model = model + @abstractmethod def fit(self, dataset: Dataset) -> None: pass diff --git a/torch_uncertainty/post_processing/calibration/matrix_scaler.py b/torch_uncertainty/post_processing/calibration/matrix_scaler.py index 1899dcbe..a0b2c86e 100644 --- a/torch_uncertainty/post_processing/calibration/matrix_scaler.py +++ b/torch_uncertainty/post_processing/calibration/matrix_scaler.py @@ -9,8 +9,8 @@ class MatrixScaler(Scaler): def __init__( self, - model: nn.Module, num_classes: int, + model: nn.Module | None = None, init_w: float = 1, init_b: float = 0, lr: float = 0.1, diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 999bd0da..3aa435ec 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -5,16 +5,16 @@ from torch.utils.data import DataLoader, Dataset from tqdm import tqdm -from torch_uncertainty.post_processing import BasePostProcessing +from torch_uncertainty.post_processing import PostProcessing -class Scaler(BasePostProcessing): +class Scaler(PostProcessing): criterion = nn.CrossEntropyLoss() trained = False def __init__( self, - model: nn.Module, + model: nn.Module | None = None, lr: float = 0.1, max_iter: int = 100, device: Literal["cpu", "cuda"] | device | None = None, @@ -33,8 +33,7 @@ def __init__( Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. On calibration of modern neural networks. In ICML 2017. """ - super().__init__() - self.model = model + super().__init__(model) self.device = device if lr <= 0: @@ -50,7 +49,7 @@ def fit( calibration_set: Dataset, save_logits: bool = False, progress: bool = True, - ) -> "Scaler": + ) -> None: """Fit the temperature parameters to the calibration data. Args: @@ -59,9 +58,6 @@ def fit( labels. Defaults to False. progress (bool, optional): Whether to show a progress bar. Defaults to True. - - Returns: - Scaler: Calibrated scaler. """ logits_list = [] labels_list = [] @@ -91,7 +87,6 @@ def calib_eval() -> float: if save_logits: self.logits = logits self.labels = labels - return self @torch.no_grad() def forward(self, inputs: Tensor) -> Tensor: diff --git a/torch_uncertainty/post_processing/calibration/temperature_scaler.py b/torch_uncertainty/post_processing/calibration/temperature_scaler.py index cfd50084..c6372bdb 100644 --- a/torch_uncertainty/post_processing/calibration/temperature_scaler.py +++ b/torch_uncertainty/post_processing/calibration/temperature_scaler.py @@ -9,7 +9,7 @@ class TemperatureScaler(Scaler): def __init__( self, - model: nn.Module, + model: nn.Module | None = None, init_val: float = 1, lr: float = 0.1, max_iter: int = 100, diff --git a/torch_uncertainty/post_processing/calibration/vector_scaler.py b/torch_uncertainty/post_processing/calibration/vector_scaler.py index 875945c0..53ce9551 100644 --- a/torch_uncertainty/post_processing/calibration/vector_scaler.py +++ b/torch_uncertainty/post_processing/calibration/vector_scaler.py @@ -9,8 +9,8 @@ class VectorScaler(Scaler): def __init__( self, - model: nn.Module, num_classes: int, + model: nn.Module | None = None, init_w: float = 1, init_b: float = 0, lr: float = 0.1, diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index e3f19d13..9a4940e4 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -13,10 +13,10 @@ class Laplace(nn.Module): def __init__( self, - model: nn.Module, task: Literal["classification", "regression"], - subset_of_weights="last_layer", - hessian_structure="kron", + model: nn.Module | None = None, + weight_subset="last_layer", + hessian_struct="kron", pred_type: Literal["glm", "nn"] = "glm", link_approx: Literal[ "mc", "probit", "bridge", "bridge_norm" @@ -27,11 +27,11 @@ def __init__( This class is a wrapper of Laplace classes from the laplace-torch library. Args: - model (nn.Module): model to be converted. task (Literal["classification", "regression"]): task type. - subset_of_weights (str): subset of weights to be considered. Defaults to + model (nn.Module): model to be converted. + weight_subset (str): subset of weights to be considered. Defaults to "last_layer". - hessian_structure (str): structure of the Hessian matrix. Defaults to + hessian_struct (str): structure of the Hessian matrix. Defaults to "kron". pred_type (Literal["glm", "nn"], optional): type of posterior predictive, See the Laplace library for more details. Defaults to "glm". @@ -47,14 +47,26 @@ def __init__( raise ImportError( "The laplace-torch library is not installed. Please install it via `pip install laplace-torch`." ) + + self.pred_type = pred_type + self.link_approx = link_approx + self.task = task + self.weight_subset = weight_subset + self.hessian_struct = hessian_struct + + if model is not None: + self._setup_model(model) + + def _setup_model(self, model) -> None: self.la = Laplace( model=model, - task=task, - subset_of_weights=subset_of_weights, - hessian_structure=hessian_structure, + task=self.task, + weight_subset=self.weight_subset, + hessian_struct=self.hessian_struct, ) - self.pred_type = pred_type - self.link_approx = link_approx + + def set_model(self, model: nn.Module) -> None: + self._setup_model(model) def fit(self, dataset: Dataset) -> None: self.la.fit(dataset=dataset) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index 6acb7884..b011a058 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -6,19 +6,19 @@ from torch.utils.data import DataLoader, Dataset from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d -from torch_uncertainty.post_processing import BasePostProcessing +from torch_uncertainty.post_processing import PostProcessing -class MCBatchNorm(BasePostProcessing): +class MCBatchNorm(PostProcessing): counter: int = 0 mc_batch_norm_layers: list[MCBatchNorm2d] = [] trained = False def __init__( self, - model: nn.Module, - num_estimators: int, - convert: bool, + model: nn.Module | None = None, + num_estimators: int = 16, + convert: bool = True, mc_batch_size: int = 32, device: Literal["cpu", "cuda"] | torch.device | None = None, ) -> None: @@ -40,29 +40,31 @@ def __init__( batch normalized deep networks. In ICML 2018. """ super().__init__() - self.mc_batch_size = mc_batch_size - if num_estimators < 1 or not isinstance(num_estimators, int): - raise ValueError( - f"num_estimators must be a positive integer, got {num_estimators}." - ) + self.convert = convert self.num_estimators = num_estimators - - self.model = deepcopy(model) - if not convert and not self._has_mcbn(): - raise ValueError( - "model does not contain any MCBatchNorm2d nor is not to be " - "converted." - ) self.device = device + + if model is not None: + self._setup_model(model) + + def _setup_model(self, model): + _mcbn_checks( + model, self.num_estimators, self.mc_batch_size, self.convert + ) + self.model = deepcopy(model) # Is it necessary? self.model = self.model.eval() - if convert: + if self.convert: self._convert() - if not self._has_mcbn(): + if not has_mcbn(self.model): raise ValueError( "model does not contain any MCBatchNorm2d after conversion." ) + def set_model(self, model: nn.Module) -> None: + self.model = model + self._setup_model(model) + def fit(self, dataset: Dataset) -> None: """Fit the model on the dataset. @@ -112,13 +114,6 @@ def forward( [self._est_forward(x) for _ in range(self.num_estimators)], dim=0 ) - def _has_mcbn(self) -> bool: - """Check if the model contains any MCBatchNorm2d layers.""" - for module in self.model.modules(): - if isinstance(module, MCBatchNorm2d): - return True - return False - def _convert(self) -> None: """Convert all BatchNorm2d layers to MCBatchNorm2d layers.""" self.replace_layers(self.model) @@ -172,3 +167,24 @@ def replace_layers(self, model: nn.Module) -> None: # Save pointers to the MC BatchNorm layers self.mc_batch_norm_layers.append(mc_layer) + + +def has_mcbn(model: nn.Module) -> bool: + """Check if the model contains any MCBatchNorm2d layers.""" + return any(isinstance(module, MCBatchNorm2d) for module in model.modules()) + + +def _mcbn_checks(model, num_estimators, mc_batch_size, convert): + if num_estimators < 1 or not isinstance(num_estimators, int): + raise ValueError( + f"num_estimators must be a positive integer, got {num_estimators}." + ) + if mc_batch_size < 1 or not isinstance(mc_batch_size, int): + raise ValueError( + f"mc_batch_size must be a positive integer, got {mc_batch_size}." + ) + if not convert and not has_mcbn(model): + raise ValueError( + "model does not contain any MCBatchNorm2d nor is not to be " + "converted." + ) From e48368dabf5bf4f2a067fa1d4feb1d1708d36f39 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Jun 2024 11:37:39 +0200 Subject: [PATCH 21/57] :sparkles: Add first version of SWAG --- torch_uncertainty/models/wrappers/swag.py | 244 ++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 torch_uncertainty/models/wrappers/swag.py diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py new file mode 100644 index 00000000..06323c72 --- /dev/null +++ b/torch_uncertainty/models/wrappers/swag.py @@ -0,0 +1,244 @@ +import copy + +import torch +from torch import Tensor, nn +from torch.utils.data import DataLoader + +from .swa import SWA + + +def flatten(lst: list[Tensor]) -> Tensor: + tmp = [i.view(-1, 1) for i in lst] + return torch.cat(tmp).view(-1) + + +def unflatten_like(vector, like_tensor_list): + """Takes a flat torch.tensor and unflattens it to a list of torch.tensors + shaped like like_tensor_list. + + """ + out_list = [] + i = 0 + for tensor in like_tensor_list: + n = tensor.numel() + out_list.append(vector[:, i : i + n].view(tensor.shape)) + i += n + return out_list + + +class SWAG(SWA): + def __init__( + self, + model: nn.Module, + cycle_start: int, + cycle_length: int, + scale: float = 1.0, + diag_covariance: bool = True, + max_num_models: int = 20, + var_clamp: float = 1e-30, + num_estimators: int = 16, + ) -> None: + """Stochastic Weight Averaging Gaussian (SWAG). + + Args: + model (nn.Module): PyTorch model to be trained. + cycle_start (int): Epoch to start SWAG. + cycle_length (int): Number of epochs between SWAG updates. + scale (float, optional): Scale of the Gaussian. Defaults to 1.0. + diag_covariance (bool, optional): Whether to use a diagonal covariance. Defaults to False. + max_num_models (int, optional): Maximum number of models to store. Defaults to 0. + var_clamp (float, optional): Minimum variance. Defaults to 1e-30. + num_estimators (int, optional): Number of posterior estimates to use. Defaults to 16. + + Reference: + Maddox, W. J. et al. A simple baseline for bayesian uncertainty in + deep learning. In NeurIPS 2019. + + Note: + Modified from https://github.com/wjmaddox/swa_gaussian + """ + super().__init__(model, cycle_start, cycle_length) + _swag_checks(scale, max_num_models, var_clamp) + + self.num_models = 0 + self.num_estimators = num_estimators + self.scale = scale + + self.diag_covariance = diag_covariance + self.max_num_models = max_num_models + self.var_clamp = var_clamp + + self.swag_params = [] + self.swag_model = copy.deepcopy(model) + self.swag_model.apply(lambda module: self.extract_parameters(module)) + + self.fit = False + self.samples = [] + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.fit: + return self.model.forward(x) + return torch.cat([mod(x) for mod in self.samples]) + + def extract_parameters(self, module: nn.Module) -> None: + for name in list(module._parameters.keys()): + if module._parameters[name] is None: + continue + data = module._parameters[name].data + delattr(module, name) + + mean, squared_mean = torch.zeros_like(data), torch.zeros_like(data) + module.register_buffer(f"{name}_mean", mean) + module.register_buffer(f"{name}_sq_mean", squared_mean) + + if not self.diag_covariance: + covariance_sqrt = torch.zeros((0, data.numel())) + module.register_buffer( + f"{name}_covariance_sqrt", covariance_sqrt + ) + + self.swag_params.append((module, name)) + + @torch.no_grad() + def update_model(self, epoch: int) -> None: + if ( + epoch >= self.cycle_start + and (epoch - self.cycle_start) % self.cycle_length == 0 + ): + print("update SWAG model") + for (module, name), param in zip( + self.swag_params, self.model.parameters(), strict=False + ): + mean = module.__getattr__(f"{name}_mean") + squared_mean = module.__getattr__(f"{name}_sq_mean") + new_param = param.data + + mean = mean * self.num_models / ( + self.num_models + 1 + ) + new_param / (self.num_models + 1) + squared_mean = squared_mean * self.num_models / ( + self.num_models + 1 + ) + new_param**2 / (self.num_models + 1) + + module.__setattr__(f"{name}_mean", mean) + module.__setattr__(f"{name}_sq_mean", squared_mean) + + if not self.diag_covariance: + covariance_sqrt = module.__getattr__( + f"{name}_covariance_sqrt" + ) + dev = (new_param - mean).view(-1, 1).t() + covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) + if self.num_models + 1 > self.max_num_models: + covariance_sqrt = covariance_sqrt[1:, :] + module.__setattr__( + f"{name}_covariance_sqrt", covariance_sqrt + ) + + self.num_models += 1 + + self.samples = [] + for _ in range(self.num_estimators): + self.sample(self.scale, self.diag_covariance) + self.samples.append(copy.deepcopy(self.swag_model)) + self.need_bn_update = True + self.fit = True + + def update_bn(self, loader: DataLoader, device) -> None: + if self.need_bn_update: + for mod in self.samples: + torch.optim.swa_utils.update_bn(loader, mod, device=device) + self.need_bn_update = False + + def sample( + self, + scale: float, + diag_covariance: bool | None = None, + block: bool = False, + seed: int | None = None, + ) -> None: # TODO: Fix sampling + if seed is not None: + torch.manual_seed(seed) + + if diag_covariance is None: + diag_covariance = self.diag_covariance + if not diag_covariance and self.diag_covariance: + raise ValueError( + "Cannot sample full rank from diagonal covariance matrices" + ) + + if not block: + self._fullrank_sample(scale, diag_covariance) + else: + raise NotImplementedError("Raise an issue if you need this feature") + + def _fullrank_sample(self, scale: float, diagonal_covariance: bool) -> None: + mean_list, sq_mean_list = [], [] + if not diagonal_covariance: + cov_mat_sqrt_list = [] + + for module, name in self.swag_params: + mean = module.__getattr__(f"{name}_mean") + sq_mean = module.__getattr__(f"{name}_sq_mean") + + if not diagonal_covariance: + cov_mat_sqrt = module.__getattr__(f"{name}_covariance_sqrt") + cov_mat_sqrt_list.append(cov_mat_sqrt.cpu()) + + mean_list.append(mean.cpu()) + sq_mean_list.append(sq_mean.cpu()) + + mean = flatten(mean_list) + sq_mean = flatten(sq_mean_list) + + # draw diagonal variance sample + var = torch.clamp(sq_mean - mean**2, self.var_clamp) + var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) + + # if covariance draw low rank sample + if not diagonal_covariance: + cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1) + + cov_sample = cov_mat_sqrt.t().matmul( + cov_mat_sqrt.new_empty( + (cov_mat_sqrt.size(0),), requires_grad=False + ).normal_() + ) + # vérifier le min + cov_sample /= (self.max_num_models - 1) ** 0.5 + + rand_sample = var_sample + cov_sample + else: + rand_sample = var_sample + + # update sample with mean and scale + sample = mean + scale**0.5 * rand_sample + sample = sample.unsqueeze(0) + + # unflatten new sample like the mean sample + samples_list = unflatten_like(sample, mean_list) + + for (module, name), sample in zip( + self.swag_params, samples_list, strict=False + ): + module.__setattr__(name, sample.cuda()) + + def load_state_dict(self, state_dict: dict, strict: bool = True) -> None: + super().load_state_dict(state_dict, strict) + + def compute_logdet(self, block=False): + raise NotImplementedError("Raise an issue if you need this feature") + + def compute_logprob(self, vec=None, block=False, diag=False): + raise NotImplementedError("Raise an issue if you need this feature") + + +def _swag_checks(scale: float, max_num_models: int, var_clamp: float) -> None: + if scale < 0: + raise ValueError(f"`scale` must be non-negative. Got {scale}.") + if max_num_models < 0: + raise ValueError( + f"`max_num_models` must be non-negative. Got {max_num_models}." + ) + if var_clamp < 0: + raise ValueError(f"`var_clamp` must be non-negative. Got {var_clamp}.") From df5330e7ec110826c74953b82ff7848dd6942a09 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Jun 2024 11:38:21 +0200 Subject: [PATCH 22/57] :hammer: Refactor wrappers --- torch_uncertainty/models/__init__.py | 1 + torch_uncertainty/models/wrappers/__init__.py | 3 +- .../models/wrappers/deep_ensembles.py | 2 + torch_uncertainty/models/wrappers/ema.py | 4 +- .../models/wrappers/mc_dropout.py | 51 ++++++++++--------- torch_uncertainty/models/wrappers/swa.py | 8 +++ 6 files changed, 42 insertions(+), 27 deletions(-) diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 82b14a7f..81dc9663 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -4,6 +4,7 @@ EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, SWA, + SWAG, CheckpointEnsemble, MCDropout, deep_ensembles, diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index 4299d3b0..d1895438 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -6,6 +6,7 @@ from .ema import EMA from .mc_dropout import MCDropout, mc_dropout from .swa import SWA +from .swag import SWAG STEP_UPDATE_MODEL = (EMA,) -EPOCH_UPDATE_MODEL = (SWA, CheckpointEnsemble) +EPOCH_UPDATE_MODEL = (SWA, SWAG, CheckpointEnsemble) diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 49640108..807ed0ed 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -92,6 +92,8 @@ def deep_ensembles( Simple and scalable predictive uncertainty estimation using deep ensembles. In NeurIPS, 2017. """ + if isinstance(models, list) and len(models) == 0: + raise ValueError("Models must not be an empty list.") if (isinstance(models, list) and len(models) == 1) or isinstance( models, nn.Module ): diff --git a/torch_uncertainty/models/wrappers/ema.py b/torch_uncertainty/models/wrappers/ema.py index b6d7476e..253bbd85 100644 --- a/torch_uncertainty/models/wrappers/ema.py +++ b/torch_uncertainty/models/wrappers/ema.py @@ -9,7 +9,7 @@ def __init__( model: nn.Module, momentum: float, ) -> None: - """Exponential moving average model. + """Exponential Moving Average. Args: model (nn.Module): The model to train and ensemble. @@ -22,7 +22,7 @@ def __init__( self.momentum = momentum self.remainder = 1 - momentum - def update_model(self, epoch: int) -> None: + def update_model(self, epoch: int | None = None) -> None: """Update the EMA model. Args: diff --git a/torch_uncertainty/models/wrappers/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py index 78c7c1d6..77c591c9 100644 --- a/torch_uncertainty/models/wrappers/mc_dropout.py +++ b/torch_uncertainty/models/wrappers/mc_dropout.py @@ -33,25 +33,10 @@ def __init__( (i.e. after all the other dropout layers). """ super().__init__() + _dropout_checks(model, num_estimators) self.last_layer = last_layer self.on_batch = on_batch - - if not hasattr(model, "dropout_rate"): - raise ValueError( - "`dropout_rate` must be set in the model to use MC Dropout." - ) - if model.dropout_rate <= 0.0: - raise ValueError( - "`dropout_rate` must be strictly positive to use MC Dropout." - ) self.model = model - - if num_estimators is None: - raise ValueError("`num_estimators` must be set to use MC Dropout.") - if num_estimators <= 0: - raise ValueError( - "`num_estimators` must be strictly positive to use MC Dropout." - ) self.num_estimators = num_estimators self.filtered_modules = list( @@ -84,14 +69,15 @@ def forward( self, x: Tensor, ) -> Tensor: - if not self.training: - if self.on_batch: - x = x.repeat(self.num_estimators, 1, 1, 1) - return self.model(x) - return torch.cat( - [self.model(x) for _ in range(self.num_estimators)], dim=0 - ) - return self.model(x) + if self.training: + return self.model(x) + if self.on_batch: + x = x.repeat(self.num_estimators, 1, 1, 1) + return self.model(x) + # Else, for loop + return torch.cat( + [self.model(x) for _ in range(self.num_estimators)], dim=0 + ) def mc_dropout( @@ -118,3 +104,20 @@ def mc_dropout( last_layer=last_layer, on_batch=on_batch, ) + + +def _dropout_checks(model: nn.Module, num_estimators: int) -> None: + if not hasattr(model, "dropout_rate"): + raise ValueError( + "`dropout_rate` must be set in the model to use MC Dropout." + ) + if model.dropout_rate <= 0.0: + raise ValueError( + "`dropout_rate` must be strictly positive to use MC Dropout." + ) + if num_estimators is None: + raise ValueError("`num_estimators` must be set to use MC Dropout.") + if num_estimators <= 0: + raise ValueError( + "`num_estimators` must be strictly positive to use MC Dropout." + ) diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py index ad9acc15..95589494 100644 --- a/torch_uncertainty/models/wrappers/swa.py +++ b/torch_uncertainty/models/wrappers/swa.py @@ -2,6 +2,7 @@ import torch from torch import nn +from torch.utils.data import DataLoader class SWA(nn.Module): @@ -51,6 +52,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model.forward(x) return self.eval_forward(x) + def update_bn(self, loader: DataLoader, device) -> None: + if self.need_bn_update: + torch.optim.swa_utils.update_bn( + loader, self.swag_model, device=device + ) + self.need_bn_update = False + def _swa_checks(cycle_start: int, cycle_length: int) -> None: if cycle_start < 0: From 1d0c595c393249605f2733d339807ca027b47f33 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Jun 2024 11:40:43 +0200 Subject: [PATCH 23/57] :hammer: Refactor the classification routine --- .../configs/lenet_checkpoint_ensemble.yaml | 1 + .../mnist/configs/lenet_swa.yaml | 1 + .../mnist/configs/lenet_swag.yaml | 63 +++++ experiments/readme.md | 2 +- tests/_dummies/baseline.py | 16 +- .../classification/deep_ensembles.py | 8 +- .../baselines/classification/resnet.py | 4 +- .../baselines/classification/vgg.py | 4 +- .../baselines/classification/wideresnet.py | 4 +- torch_uncertainty/routines/classification.py | 245 ++++++++---------- 10 files changed, 201 insertions(+), 147 deletions(-) create mode 100644 experiments/classification/mnist/configs/lenet_swag.yaml diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index e92e6ca7..c5398a87 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -57,6 +57,7 @@ model: - 70 num_classes: 10 loss: CrossEntropyLoss + is_ensemble: true data: root: ./data batch_size: 128 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index b79647c1..8de26b64 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -47,6 +47,7 @@ model: cycle_length: 5 num_classes: 10 loss: CrossEntropyLoss + is_ensemble: true data: root: ./data batch_size: 128 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml new file mode 100644 index 00000000..bee2e7d3 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -0,0 +1,63 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_swag + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.SWAG + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + cycle_start: 2 + cycle_length: 5 + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/readme.md b/experiments/readme.md index 0035c5a7..8f8996c6 100644 --- a/experiments/readme.md +++ b/experiments/readme.md @@ -14,6 +14,6 @@ Torch-Uncertainty proposes various benchmarks to evaluate uncertainty quantifica *Work in progress* -## Monocular Depth Estimation +## Pixel Regression *Work in progress* diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 22cf6e5c..d763800d 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -8,6 +8,8 @@ NormalLayer, ) from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.routines import ( ClassificationRoutine, PixelRegressionRoutine, @@ -24,9 +26,9 @@ def __new__( cls, num_classes: int, in_channels: int, - loss: type[nn.Module], + loss: nn.Module, baseline_type: str = "single", - optim_recipe=None, + optim_recipe=optim_cifar10_resnet18, with_feats: bool = True, with_linear: bool = True, ood_criterion: str = "msp", @@ -69,12 +71,14 @@ def __new__( format_batch_fn=nn.Identity(), log_plots=True, optim_recipe=optim_recipe(model), - num_estimators=1, + is_ensemble=False, mixup_params=mixup_params, ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, - calibration_set="val" if calibrate else None, + post_processing_method=TemperatureScaler() + if calibrate + else None, save_in_csv=save_in_csv, ) # baseline_type == "ensemble": @@ -89,11 +93,11 @@ def __new__( optim_recipe=optim_recipe(model), format_batch_fn=RepeatTarget(2), log_plots=True, - num_estimators=2, + is_ensemble=True, ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, - calibration_set="val" if calibrate else None, + post_processing_method=TemperatureScaler() if calibrate else None, save_in_csv=save_in_csv, ) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index fd6bc6c1..a33480b4 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -24,10 +24,10 @@ def __init__( eval_ood: bool = False, eval_grouping_loss: bool = False, ood_criterion: Literal[ - "msp", "logits", "energy", "entropy", "mi", "VR" + "msp", "logit", "energy", "entropy", "mi", "vr" ] = "msp", log_plots: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", ) -> None: log_path = Path(log_path) @@ -45,14 +45,14 @@ def __init__( optim_recipe=None, ).eval() models.append(trained_model.model) - + print(models) de = deep_ensembles(models=models) super().__init__( num_classes=num_classes, model=de, loss=None, - num_estimators=de.num_estimators, + is_ensemble=de.num_estimators > 1, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index bd41cbe6..9cfd4bf7 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -60,7 +60,7 @@ def __init__( ] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_grouping_loss: bool = False, num_calibration_bins: int = 15, @@ -213,7 +213,7 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in self.ensemble, format_batch_fn=format_batch_fn, mixup_params=mixup_params, eval_ood=eval_ood, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 8988b613..0f3a5a59 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -41,7 +41,7 @@ def __init__( ] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_grouping_loss: bool = False, ) -> None: @@ -162,7 +162,7 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in self.ensemble, format_batch_fn=format_batch_fn, mixup_params=mixup_params, eval_ood=eval_ood, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 935049cd..b2f5e82d 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -52,7 +52,7 @@ def __init__( ] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_grouping_loss: bool = False, ) -> None: @@ -184,7 +184,7 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in self.ensemble, format_batch_fn=format_batch_fn, mixup_params=mixup_params, eval_ood=eval_ood, diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 40d50e3d..9f297897 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -36,10 +36,15 @@ from torch_uncertainty.models import ( EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, - CheckpointEnsemble, ) -from torch_uncertainty.post_processing import TemperatureScaler -from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup +from torch_uncertainty.post_processing import PostProcessing +from torch_uncertainty.transforms import ( + Mixup, + MixupIO, + RegMixup, + RepeatTarget, + WarpingMixup, +) from torch_uncertainty.utils import csv_writer, plot_hist MIXUP_PARAMS = { @@ -59,7 +64,7 @@ def __init__( model: nn.Module, num_classes: int, loss: nn.Module, - num_estimators: int = 1, + is_ensemble: bool = False, format_batch_fn: nn.Module | None = None, optim_recipe: dict | Optimizer | None = None, mixup_params: dict | None = None, @@ -68,10 +73,11 @@ def __init__( ood_criterion: Literal[ "msp", "logit", "energy", "entropy", "mi", "vr" ] = "msp", + post_processing_method: PostProcessing | None = None, + calibration_set: Literal["val", "test"] = "val", + num_calibration_bins: int = 15, log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, - num_calibration_bins: int = 15, ) -> None: r"""Routine for training & testing on **classification tasks**. @@ -79,8 +85,8 @@ def __init__( model (torch.nn.Module): Model to train. num_classes (int): Number of classes. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. - num_estimators (int, optional): Number of estimators for the - ensemble. Defaults to ``1`` (single model). + is_ensemble (bool, optional): Indicates whether the model is an + ensemble at test time or not. Defaults to ``False``. format_batch_fn (torch.nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and @@ -94,27 +100,26 @@ def __init__( eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. ood_criterion (str, optional): OOD criterion. Available options are - - ``"msp"`` (default): Maximum softmax probability. - ``"logit"``: Maximum logit. - ``"energy"``: Logsumexp of the mean logits. - ``"entropy"``: Entropy of the mean prediction. - ``"mi"``: Mutual information of the ensemble. - ``"vr"``: Variation ratio of the ensemble. - + post_processing_method (PostProcessing, optional): Post-processing method + to train on the calibration set. No post-processing if None. + Defaults to ``None``. + calibration_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. + num_calibration_bins (int, optional): Number of bins to compute calibration + metrics. Defaults to ``15``. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. save_in_csv(bool, optional): Save the results in csv. Defaults to ``False``. - calibration_set (str, optional): The post-hoc calibration dataset to - use for scaling. If not ``None``, it uses either the validation - set when set to ``"val"`` or the test set when set to ``"test"``. - Defaults to ``None``. Else, no post-hoc calibration. - num_calibration_bins (int, optional): Number of bins to compute calibration - metrics. Defaults to ``15``. Warning: - You must define :attr:`optim_recipe` if you do not use the CLI. + You must define :attr:`optim_recipe` if you do not use the Lightning CLI. Note: :attr:`optim_recipe` can be anything that can be returned by @@ -125,17 +130,19 @@ def __init__( _classification_routine_checks( model=model, num_classes=num_classes, - num_estimators=num_estimators, + is_ensemble=is_ensemble, ood_criterion=ood_criterion, eval_grouping_loss=eval_grouping_loss, num_calibration_bins=num_calibration_bins, + mixup_params=mixup_params, + post_processing_method=post_processing_method, + format_batch_fn=format_batch_fn, ) if format_batch_fn is None: format_batch_fn = nn.Identity() self.num_classes = num_classes - self.num_estimators = num_estimators self.eval_ood = eval_ood self.eval_grouping_loss = eval_grouping_loss self.ood_criterion = ood_criterion @@ -145,30 +152,46 @@ def __init__( self.binary_cls = num_classes == 1 self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) - + self.num_calibration_bins = num_calibration_bins self.model = model self.loss = loss self.format_batch_fn = format_batch_fn self.optim_recipe = optim_recipe + self.is_ensemble = is_ensemble + + self.post_processing_method = post_processing_method + if self.post_processing_method is not None: + self.post_processing_method.set_model(self.model) + + self._init_metrics() + self.mixup = self._init_mixup(mixup_params) + + self.is_elbo = isinstance(self.loss, ELBOLoss) + if self.is_elbo: + self.loss.set_model(self.model) + self.is_dec = isinstance(self.loss, DECLoss) - # metrics + self.id_logit_storage = None + self.ood_logit_storage = None + + def _init_metrics(self) -> None: task = "binary" if self.binary_cls else "multiclass" cls_metrics = MetricCollection( { - "cls/Acc": Accuracy(task=task, num_classes=num_classes), - "cls/Brier": BrierScore(num_classes=num_classes), + "cls/Acc": Accuracy(task=task, num_classes=self.num_classes), + "cls/Brier": BrierScore(num_classes=self.num_classes), "cls/NLL": CategoricalNLL(), "cal/ECE": CalibrationError( task=task, - num_bins=num_calibration_bins, - num_classes=num_classes, + num_bins=self.num_calibration_bins, + num_classes=self.num_classes, ), "cal/aECE": CalibrationError( task=task, adaptive=True, - num_bins=num_calibration_bins, - num_classes=num_classes, + num_bins=self.num_calibration_bins, + num_classes=self.num_classes, ), "sc/AURC": AURC(), "sc/CovAt5Risk": CovAt5Risk(), @@ -204,7 +227,7 @@ def __init__( self.test_ood_entropy = Entropy() # metrics for ensembles only - if self.num_estimators > 1 or isinstance(model, CheckpointEnsemble): + if self.is_ensemble: ens_metrics = MetricCollection( { "Disagreement": Disagreement(), @@ -218,23 +241,12 @@ def __init__( if self.eval_ood: self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") - if num_estimators == 1: - self.mixup = self._init_mixup(mixup_params) - - if self.eval_grouping_loss: - grouping_loss = MetricCollection( - {"cls/grouping_loss": GroupingLoss()} - ) - self.val_grouping_loss = grouping_loss.clone(prefix="val/") - self.test_grouping_loss = grouping_loss.clone(prefix="test/") - - self.is_elbo = isinstance(self.loss, ELBOLoss) - if self.is_elbo: - self.loss.set_model(self.model) - self.is_dec = isinstance(self.loss, DECLoss) - - self.id_logit_storage = None - self.ood_logit_storage = None + if self.eval_grouping_loss: + grouping_loss = MetricCollection( + {"cls/grouping_loss": GroupingLoss()} + ) + self.val_grouping_loss = grouping_loss.clone(prefix="val/") + self.test_grouping_loss = grouping_loss.clone(prefix="test/") def _init_mixup(self, mixup_params: dict | None) -> Callable: if mixup_params is None: @@ -284,6 +296,21 @@ def _init_mixup(self, mixup_params: dict | None) -> Callable: ) return Identity() + def _apply_mixup( + self, batch: tuple[Tensor, Tensor] + ) -> tuple[Tensor, Tensor]: + if not self.is_ensemble: + if self.mixup_params["mixtype"] == "kernel_warping": + if self.mixup_params["dist_sim"] == "emb": + with torch.no_grad(): + feats = self.model.feats_forward(batch[0]).detach() + batch = self.mixup(*batch, feats) + elif self.mixup_params["dist_sim"] == "inp": + batch = self.mixup(*batch, batch[0]) + else: + batch = self.mixup(*batch) + return batch + def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -294,40 +321,22 @@ def on_train_start(self) -> None: ) def on_validation_start(self) -> None: - if ( - self.need_epoch_update and self.current_epoch > 0 - ): # workaround of sanity checks + if self.need_epoch_update and not self.trainer.sanity_checking: self.model.update_model(self.current_epoch) - if ( - hasattr(self.model, "need_bn_update") - and self.model.need_bn_update - ): - torch.optim.swa_utils.update_bn( - self.trainer.train_dataloader, - self.model, - device=self.device, + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device ) - self.model.need_bn_update = False - - if isinstance(self.model, CheckpointEnsemble): - self.num_estimators = self.model.num_estimators def on_test_start(self) -> None: - if isinstance(self.calibration_set, str) and self.calibration_set in [ - "val", - "test", - ]: + if self.post_processing_method is not None: calibration_dataset = ( self.trainer.datamodule.val_dataloader().dataset if self.calibration_set == "val" else self.trainer.datamodule.test_dataloader()[0].dataset ) with torch.inference_mode(False): - self.cal_model = TemperatureScaler( - model=self.model, device=self.device - ).fit(calibration_dataset) - else: - self.cal_model = None + self.post_processing_method.fit(calibration_dataset) if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_logit_storage = [] @@ -355,23 +364,6 @@ def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: logits = self.model(inputs) return logits - def _apply_mixup( - self, batch: tuple[Tensor, Tensor] - ) -> tuple[Tensor, Tensor]: - # Mixup only for single models - if self.num_estimators == 1: - if self.mixup_params["mixtype"] == "kernel_warping": - if self.mixup_params["dist_sim"] == "emb": - with torch.no_grad(): - feats = self.model.feats_forward(batch[0]).detach() - - batch = self.mixup(*batch, feats) - elif self.mixup_params["dist_sim"] == "inp": - batch = self.mixup(*batch, batch[0]) - else: - batch = self.mixup(*batch) - return batch - def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: @@ -399,9 +391,9 @@ def training_step( def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - inputs, target = batch + inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) @@ -409,10 +401,10 @@ def validation_step( probs_per_est = F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) - self.val_cls_metrics.update(probs, target) + self.val_cls_metrics.update(probs, targets) if self.eval_grouping_loss: - self.val_grouping_loss.update(probs, target, self.features) + self.val_grouping_loss.update(probs, targets, self.features) def test_step( self, @@ -420,18 +412,9 @@ def test_step( batch_idx: int, dataloader_idx: int = 0, ) -> None: - inputs, target = batch - logits = self.forward( - inputs, save_feats=self.eval_grouping_loss - ) # (m*b, c) - if logits.size(0) % self.num_estimators != 0: # coverage: ignore - raise ValueError( - f"The number of predicted samples {logits.size(0)} is not " - "divisible by the reported number of estimators " - f"{self.num_estimators} of the routine. Please check the " - "correspondence between these values." - ) - logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators) + inputs, targets = batch + logits = self.forward(inputs, save_feats=self.eval_grouping_loss) + logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) if self.binary_cls: probs_per_est = torch.sigmoid(logits) @@ -459,20 +442,19 @@ def test_step( else: ood_scores = -confs - # Scaling for single models - if self.num_estimators == 1 and self.cal_model is not None: - cal_logits = self.cal_model(inputs) - cal_probs = F.softmax(cal_logits, dim=-1) - self.ts_cls_metrics.update(cal_probs, target) + if self.post_processing_method is not None: + pp_logits = self.post_processing_method(inputs) + pp_probs = F.softmax(pp_logits, dim=-1) + self.ts_cls_metrics.update(pp_probs, targets) if dataloader_idx == 0: # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( probs.squeeze(-1) if self.binary_cls else probs, - target, + targets, ) if self.eval_grouping_loss: - self.test_grouping_loss.update(probs, target, self.features) + self.test_grouping_loss.update(probs, targets, self.features) self.log_dict( self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False @@ -485,19 +467,19 @@ def test_step( add_dataloader_idx=False, ) - if self.num_estimators > 1: + if self.is_ensemble > 1: self.test_id_ens_metrics.update(probs_per_est) if self.eval_ood: self.test_ood_metrics.update( - ood_scores, torch.zeros_like(target) + ood_scores, torch.zeros_like(targets) ) if self.id_logit_storage is not None: self.id_logit_storage.append(logits.detach().cpu()) elif self.eval_ood and dataloader_idx == 1: - self.test_ood_metrics.update(ood_scores, torch.ones_like(target)) + self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_ood_entropy(probs) self.log( "ood/Entropy", @@ -505,7 +487,7 @@ def test_step( on_epoch=True, add_dataloader_idx=False, ) - if self.num_estimators > 1: + if self.is_ensemble: self.test_ood_ens_metrics.update(probs_per_est) if self.ood_logit_storage is not None: @@ -528,11 +510,7 @@ def on_test_epoch_end(self) -> None: {"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True ) - if ( - self.num_estimators == 1 - and self.calibration_set is not None - and self.cal_model is not None - ): + if self.post_processing_method is not None: tmp_metrics = self.ts_cls_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -543,7 +521,7 @@ def on_test_epoch_end(self) -> None: sync_dist=True, ) - if self.num_estimators > 1: + if self.is_ensemble: tmp_metrics = self.test_id_ens_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -556,7 +534,7 @@ def on_test_epoch_end(self) -> None: # already logged result_dict.update({"ood/Entropy": self.test_ood_entropy.compute()}) - if self.num_estimators > 1: + if self.is_ensemble: tmp_metrics = self.test_ood_ens_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -570,7 +548,7 @@ def on_test_epoch_end(self) -> None: self.test_cls_metrics["sc/AURC"].plot()[0], ) - if self.cal_model is not None: + if self.post_processing_method is not None: self.logger.experiment.add_figure( "Reliabity diagram after calibration", self.ts_cls_metrics["cal/ECE"].plot()[0], @@ -619,17 +597,14 @@ def save_results_to_csv(self, results: dict[str, float]) -> None: def _classification_routine_checks( model: nn.Module, num_classes: int, - num_estimators: int, + is_ensemble: bool, ood_criterion: str, eval_grouping_loss: bool, num_calibration_bins: int, + mixup_params: dict | None, + post_processing_method: PostProcessing | None, + format_batch_fn: nn.Module | None, ) -> None: - if not isinstance(num_estimators, int) or num_estimators < 1: - raise ValueError( - "The number of estimators must be a positive integer >= 1." - f"Got {num_estimators}." - ) - if ood_criterion not in [ "msp", "logit", @@ -643,13 +618,13 @@ def _classification_routine_checks( f" 'mi' or 'vr'. Got {ood_criterion}." ) - if num_estimators == 1 and ood_criterion in ["mi", "vr"]: + if not is_ensemble and ood_criterion in ["mi", "vr"]: raise ValueError( "You cannot use mutual information or variation ratio with a single" " model." ) - if num_estimators != 1 and eval_grouping_loss: + if is_ensemble and eval_grouping_loss: raise NotImplementedError( "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." ) @@ -678,3 +653,13 @@ def _classification_routine_checks( raise ValueError( f"num_calibration_bins must be at least 2, got {num_calibration_bins}." ) + + if mixup_params is not None and isinstance(format_batch_fn, RepeatTarget): + raise ValueError( + "Mixup is not supported for ensembles at training time. Please set mixup_params to None." + ) + + if post_processing_method is not None and is_ensemble: + raise ValueError( + "Ensembles and post-processing methods cannot be used together. Raise an issue if needed." + ) From 6e989f65d7581e02114fdccbd73fa947aa744fab Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Jun 2024 11:41:21 +0200 Subject: [PATCH 24/57] :book: Add links to the conf. in ReadMe --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 74b6af12..072439c1 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ _TorchUncertainty_ is a package designed to help you leverage [uncertainty quant :books: Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). :books: -TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **WACV 2024** and **ECCV 2024**. +TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **[WACV](https://wacv2024.thecvf.com/) 2024**, **[HAICON](https://haicon24.de/) 2024** and **[ECCV](https://eccv.ecva.net/) 2024**. --- @@ -69,8 +69,8 @@ To date, the following deep learning baselines have been implemented: - MIMO - Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) -- Checkpoint Ensembles -- Stochastic Weight Averaging +- Checkpoint Ensembles & Snapshot Ensembles +- Stochastic Weight Averaging & Stochastic Weight Averaging Gaussian - Regression with Beta Gaussian NLL Loss - Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) From fbc8c551df78aecd7a55180aa7aad31d4cd847c0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Jun 2024 11:53:26 +0200 Subject: [PATCH 25/57] :white_check_mark: Update tests --- tests/baselines/test_deep_ensembles.py | 4 +++- tests/routines/test_classification.py | 22 ++-------------------- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index fbb7a512..cd8642cf 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -9,7 +9,9 @@ class TestDeepEnsembles: """Testing the Deep Ensembles baseline class.""" def test_failure(self): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Models must not be an empty list." + ): DeepEnsemblesBaseline( log_path=".", checkpoint_ids=[], diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 4c40da79..ebd6a736 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -9,7 +9,6 @@ dummy_model, ) from torch_uncertainty.losses import DECLoss, ELBOLoss -from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine from torch_uncertainty.utils import TUTrainer @@ -30,7 +29,6 @@ def test_one_estimator_binary(self): in_channels=dm.num_channels, num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="msp", ) @@ -53,7 +51,6 @@ def test_two_estimators_binary(self): in_channels=dm.num_channels, num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="logit", ) @@ -77,7 +74,6 @@ def test_one_estimator_two_classes(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -103,7 +99,6 @@ def test_one_estimator_two_classes_timm(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -131,7 +126,6 @@ def test_one_estimator_two_classes_mixup(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -158,7 +152,6 @@ def test_one_estimator_two_classes_mixup_io(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -185,7 +178,6 @@ def test_one_estimator_two_classes_regmixup(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -212,7 +204,6 @@ def test_one_estimator_two_classes_kernel_warping_emb(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -239,7 +230,6 @@ def test_one_estimator_two_classes_kernel_warping_inp(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -267,7 +257,6 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="energy", eval_ood=True, @@ -294,7 +283,6 @@ def test_two_estimators_two_classes_mi(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=DECLoss(1, 1e-2), - optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ood_criterion="mi", eval_ood=True, @@ -328,7 +316,6 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): loss=ELBOLoss( None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4 ), - optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ood_criterion="vr", eval_ood=True, @@ -341,11 +328,6 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): model(dm.get_test_set()[0][0]) def test_classification_failures(self): - # num_estimators - with pytest.raises(ValueError): - ClassificationRoutine( - num_classes=10, model=nn.Module(), loss=None, num_estimators=-1 - ) # num_classes with pytest.raises(ValueError): ClassificationRoutine(num_classes=0, model=nn.Module(), loss=None) @@ -355,7 +337,7 @@ def test_classification_failures(self): num_classes=10, model=nn.Module(), loss=None, - num_estimators=1, + is_ensemble=False, ood_criterion="mi", ) with pytest.raises(ValueError): @@ -398,7 +380,7 @@ def test_classification_failures(self): num_classes=10, model=nn.Module(), loss=None, - num_estimators=2, + is_ensemble=True, eval_grouping_loss=True, ) From 27bb61023a850a51dc06892496b1259393d02b05 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 11:24:14 +0200 Subject: [PATCH 26/57] :sparkles: Improve SWAG code --- torch_uncertainty/models/wrappers/swag.py | 257 ++++++++++------------ 1 file changed, 119 insertions(+), 138 deletions(-) diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index 06323c72..04b84a08 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -1,61 +1,51 @@ import copy import torch +from swa import SWA from torch import Tensor, nn from torch.utils.data import DataLoader -from .swa import SWA - - -def flatten(lst: list[Tensor]) -> Tensor: - tmp = [i.view(-1, 1) for i in lst] - return torch.cat(tmp).view(-1) - - -def unflatten_like(vector, like_tensor_list): - """Takes a flat torch.tensor and unflattens it to a list of torch.tensors - shaped like like_tensor_list. - - """ - out_list = [] - i = 0 - for tensor in like_tensor_list: - n = tensor.numel() - out_list.append(vector[:, i : i + n].view(tensor.shape)) - i += n - return out_list - class SWAG(SWA): + swag_stats: dict[str, Tensor] + def __init__( self, model: nn.Module, cycle_start: int, cycle_length: int, scale: float = 1.0, - diag_covariance: bool = True, + diag_covariance: bool = False, max_num_models: int = 20, var_clamp: float = 1e-30, num_estimators: int = 16, ) -> None: """Stochastic Weight Averaging Gaussian (SWAG). + Update the SWAG posterior every `cycle_length` epochs starting at + `cycle_start`. Samples :attr:`num_estimators` models from the SWAG + posterior after each update. Uses the SWAG posterior estimation only + at test time. Otherwise, uses the base model for training. + Args: model (nn.Module): PyTorch model to be trained. cycle_start (int): Epoch to start SWAG. cycle_length (int): Number of epochs between SWAG updates. scale (float, optional): Scale of the Gaussian. Defaults to 1.0. - diag_covariance (bool, optional): Whether to use a diagonal covariance. Defaults to False. - max_num_models (int, optional): Maximum number of models to store. Defaults to 0. + diag_covariance (bool, optional): Whether to use a diagonal + covariance. Defaults to False. + max_num_models (int, optional): Maximum number of models to store. + Defaults to 0. var_clamp (float, optional): Minimum variance. Defaults to 1e-30. - num_estimators (int, optional): Number of posterior estimates to use. Defaults to 16. + num_estimators (int, optional): Number of posterior estimates to + use. Defaults to 16. Reference: Maddox, W. J. et al. A simple baseline for bayesian uncertainty in deep learning. In NeurIPS 2019. Note: - Modified from https://github.com/wjmaddox/swa_gaussian + Originates from https://github.com/wjmaddox/swa_gaussian. """ super().__init__(model, cycle_start, cycle_length) _swag_checks(scale, max_num_models, var_clamp) @@ -68,10 +58,7 @@ def __init__( self.max_num_models = max_num_models self.var_clamp = var_clamp - self.swag_params = [] - self.swag_model = copy.deepcopy(model) - self.swag_model.apply(lambda module: self.extract_parameters(module)) - + self.initialize_stats() self.fit = False self.samples = [] @@ -80,71 +67,77 @@ def eval_forward(self, x: torch.Tensor) -> torch.Tensor: return self.model.forward(x) return torch.cat([mod(x) for mod in self.samples]) - def extract_parameters(self, module: nn.Module) -> None: - for name in list(module._parameters.keys()): - if module._parameters[name] is None: - continue - data = module._parameters[name].data - delattr(module, name) - - mean, squared_mean = torch.zeros_like(data), torch.zeros_like(data) - module.register_buffer(f"{name}_mean", mean) - module.register_buffer(f"{name}_sq_mean", squared_mean) + def initialize_stats(self) -> None: + """Initialize the SWAG dictionary of statistics.""" + self.swag_stats = {} + for name_p, param in self.model.named_parameters(): + mean, squared_mean = ( + torch.zeros_like(param, device="cpu"), + torch.zeros_like(param, device="cpu"), + ) + self.swag_stats[name_p + "_mean"] = mean + self.swag_stats[name_p + "_sq_mean"] = squared_mean if not self.diag_covariance: - covariance_sqrt = torch.zeros((0, data.numel())) - module.register_buffer( - f"{name}_covariance_sqrt", covariance_sqrt - ) - - self.swag_params.append((module, name)) + covariance_sqrt = torch.zeros((0, param.numel()), device="cpu") + self.swag_stats[name_p + "_covariance_sqrt"] = covariance_sqrt @torch.no_grad() def update_model(self, epoch: int) -> None: - if ( + """Update the SWAG posterior. + + The update is performed if the epoch is greater than the cycle start + and the difference between the epoch and the cycle start is a multiple + of the cycle length. + + Args: + epoch (int): Current epoch. + """ + if not ( epoch >= self.cycle_start and (epoch - self.cycle_start) % self.cycle_length == 0 ): - print("update SWAG model") - for (module, name), param in zip( - self.swag_params, self.model.parameters(), strict=False - ): - mean = module.__getattr__(f"{name}_mean") - squared_mean = module.__getattr__(f"{name}_sq_mean") - new_param = param.data - - mean = mean * self.num_models / ( - self.num_models + 1 - ) + new_param / (self.num_models + 1) - squared_mean = squared_mean * self.num_models / ( - self.num_models + 1 - ) + new_param**2 / (self.num_models + 1) - - module.__setattr__(f"{name}_mean", mean) - module.__setattr__(f"{name}_sq_mean", squared_mean) - - if not self.diag_covariance: - covariance_sqrt = module.__getattr__( - f"{name}_covariance_sqrt" - ) - dev = (new_param - mean).view(-1, 1).t() - covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) - if self.num_models + 1 > self.max_num_models: - covariance_sqrt = covariance_sqrt[1:, :] - module.__setattr__( - f"{name}_covariance_sqrt", covariance_sqrt - ) - - self.num_models += 1 - - self.samples = [] - for _ in range(self.num_estimators): - self.sample(self.scale, self.diag_covariance) - self.samples.append(copy.deepcopy(self.swag_model)) - self.need_bn_update = True - self.fit = True + return + + for name_p, param in self.model.named_parameters(): + mean = self.swag_stats[name_p + "_mean"] + squared_mean = self.swag_stats[name_p + "_sq_mean"] + new_param = param.data.detach().cpu() + + mean = mean * self.num_models / ( + self.num_models + 1 + ) + new_param / (self.num_models + 1) + squared_mean = squared_mean * self.num_models / ( + self.num_models + 1 + ) + new_param**2 / (self.num_models + 1) + + self.swag_stats[name_p + "_mean"] = mean + self.swag_stats[name_p + "_sq_mean"] = squared_mean + + if not self.diag_covariance: + covariance_sqrt = self.swag_stats[name_p + "_covariance_sqrt"] + dev = (new_param - mean).view(-1, 1).t() + covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) + if self.num_models + 1 > self.max_num_models: + covariance_sqrt = covariance_sqrt[1:, :] + self.swag_stats[name_p + "_covariance_sqrt"] = covariance_sqrt + + self.num_models += 1 + + self.samples = [ + self.sample(self.scale, self.diag_covariance) + for _ in range(self.num_estimators) + ] + self.need_bn_update = True + self.fit = True def update_bn(self, loader: DataLoader, device) -> None: + """Update the bachnorm statistics of the current SWAG samples. + + Args: + loader (DataLoader): DataLoader to update the batchnorm statistics. + device (torch.device): Device to perform the update. + """ if self.need_bn_update: for mod in self.samples: torch.optim.swa_utils.update_bn(loader, mod, device=device) @@ -156,7 +149,20 @@ def sample( diag_covariance: bool | None = None, block: bool = False, seed: int | None = None, - ) -> None: # TODO: Fix sampling + ) -> nn.Module: + """Sample a model from the SWAG posterior. + + Args: + scale (float): Rescale coefficient of the Gaussian. + diag_covariance (bool, optional): Whether to use a diagonal + covariance. Defaults to None. + block (bool, optional): Whether to sample a block diagonal + covariance. Defaults to False. + seed (int, optional): Random seed. Defaults to None. + + Returns: + nn.Module: Sampled model. + """ if seed is not None: torch.manual_seed(seed) @@ -164,67 +170,42 @@ def sample( diag_covariance = self.diag_covariance if not diag_covariance and self.diag_covariance: raise ValueError( - "Cannot sample full rank from diagonal covariance matrices" + "Cannot sample full rank from diagonal covariance matrix." ) if not block: - self._fullrank_sample(scale, diag_covariance) - else: - raise NotImplementedError("Raise an issue if you need this feature") + return self._fullrank_sample(scale, diag_covariance) + raise NotImplementedError("Raise an issue if you need this feature.") - def _fullrank_sample(self, scale: float, diagonal_covariance: bool) -> None: - mean_list, sq_mean_list = [], [] - if not diagonal_covariance: - cov_mat_sqrt_list = [] + def _fullrank_sample( + self, scale: float, diagonal_covariance: bool + ) -> nn.Module: + new_sample = copy.deepcopy(self.model) - for module, name in self.swag_params: - mean = module.__getattr__(f"{name}_mean") - sq_mean = module.__getattr__(f"{name}_sq_mean") + for name_p, param in new_sample.named_parameters(): + mean = self.swag_stats[name_p + "_mean"] + sq_mean = self.swag_stats[name_p + "_sq_mean"] if not diagonal_covariance: - cov_mat_sqrt = module.__getattr__(f"{name}_covariance_sqrt") - cov_mat_sqrt_list.append(cov_mat_sqrt.cpu()) - - mean_list.append(mean.cpu()) - sq_mean_list.append(sq_mean.cpu()) - - mean = flatten(mean_list) - sq_mean = flatten(sq_mean_list) + cov_mat_sqrt = self.swag_stats[name_p + "_covariance_sqrt"] - # draw diagonal variance sample - var = torch.clamp(sq_mean - mean**2, self.var_clamp) - var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) + # draw diagonal variance sample + var = torch.clamp(sq_mean - mean**2, self.var_clamp) + var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) - # if covariance draw low rank sample - if not diagonal_covariance: - cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1) - - cov_sample = cov_mat_sqrt.t().matmul( - cov_mat_sqrt.new_empty( - (cov_mat_sqrt.size(0),), requires_grad=False - ).normal_() - ) - # vérifier le min - cov_sample /= (self.max_num_models - 1) ** 0.5 - - rand_sample = var_sample + cov_sample - else: - rand_sample = var_sample - - # update sample with mean and scale - sample = mean + scale**0.5 * rand_sample - sample = sample.unsqueeze(0) - - # unflatten new sample like the mean sample - samples_list = unflatten_like(sample, mean_list) - - for (module, name), sample in zip( - self.swag_params, samples_list, strict=False - ): - module.__setattr__(name, sample.cuda()) - - def load_state_dict(self, state_dict: dict, strict: bool = True) -> None: - super().load_state_dict(state_dict, strict) + # if covariance draw low rank sample + if not diagonal_covariance: + cov_sample = cov_mat_sqrt.t() @ torch.randn( + (cov_mat_sqrt.size(0),) + ) + cov_sample /= (self.max_num_models - 1) ** 0.5 + rand_sample = var_sample + cov_sample.view_as(var_sample) + else: + rand_sample = var_sample + + sample = mean + scale**0.5 * rand_sample + param.data = sample.to(device=param.device, dtype=param.dtype) + return new_sample def compute_logdet(self, block=False): raise NotImplementedError("Raise an issue if you need this feature") From 0010d95bc628bf07d63ff214f34b24b6c4a422db Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 11:28:16 +0200 Subject: [PATCH 27/57] :shirt: Minor fix --- torch_uncertainty/layers/packed.py | 20 ++++++++++---------- torch_uncertainty/models/wrappers/swag.py | 3 ++- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 6f742b17..094b4834 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -6,12 +6,12 @@ def check_packed_parameters_consistency( - alpha: float, gamma: int, num_estimators: int + alpha: int, gamma: int, num_estimators: int ) -> None: """Check the consistency of the parameters of the Packed-Ensembles layers. Args: - alpha (float): The width multiplier of the layer. + alpha (int): The width multiplier of the layer. gamma (int): The number of groups in the ensemble. num_estimators (int): The number of estimators in the ensemble. """ @@ -49,7 +49,7 @@ def __init__( self, in_features: int, out_features: int, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, bias: bool = True, @@ -67,7 +67,7 @@ def __init__( Args: in_features (int): Number of input features of the linear layer. out_features (int): Number of channels produced by the linear layer. - alpha (float): The width multiplier of the linear layer. + alpha (int): The width multiplier of the linear layer. num_estimators (int): The number of estimators grouped in the layer. gamma (int, optional): Defaults to ``1``. bias (bool, optional): It ``True``, adds a learnable bias to the @@ -174,7 +174,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: _size_1_t, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, stride: _size_1_t = 1, @@ -195,7 +195,7 @@ def __init__( in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size (int or tuple): Size of the convolving kernel. - alpha (float): The channel multiplier of the convolutional layer. + alpha (int): The channel multiplier of the convolutional layer. num_estimators (int): Number of estimators in the ensemble. gamma (int, optional): Defaults to ``1``. stride (int or tuple, optional): Stride of the convolution. @@ -302,7 +302,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: _size_2_t, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, stride: _size_2_t = 1, @@ -323,7 +323,7 @@ def __init__( in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size (int or tuple): Size of the convolving kernel. - alpha (float): The channel multiplier of the convolutional layer. + alpha (int): The channel multiplier of the convolutional layer. num_estimators (int): Number of estimators in the ensemble. gamma (int, optional): Defaults to ``1``. stride (int or tuple, optional): Stride of the convolution. @@ -430,7 +430,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: _size_3_t, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, stride: _size_3_t = 1, @@ -451,7 +451,7 @@ def __init__( in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size (int or tuple): Size of the convolving kernel. - alpha (float): The channel multiplier of the convolutional layer. + alpha (int): The channel multiplier of the convolutional layer. num_estimators (int): Number of estimators in the ensemble. gamma (int, optional): Defaults to ``1``. stride (int or tuple, optional): Stride of the convolution. diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index 04b84a08..90a514af 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -1,10 +1,11 @@ import copy import torch -from swa import SWA from torch import Tensor, nn from torch.utils.data import DataLoader +from .swa import SWA + class SWAG(SWA): swag_stats: dict[str, Tensor] From 2c63b3bd5c1f0ea5fd97211f67494736a6afe4d2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 13:22:58 +0200 Subject: [PATCH 28/57] :wrench: Fix online install --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 822924d0..a8ab9e8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "huggingface-hub", "scikit-learn", "matplotlib", + "numpy<2", "opencv-python", "glest==0.0.1a0", ] From 914599b111cc63e33d7e96fee9bab73a703e058f Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 15:24:48 +0200 Subject: [PATCH 29/57] :sparkles: Add a full scheduler for SWA & SWAG & update config --- .../classification/mnist/configs/lenet.yaml | 10 +++-- .../mnist/configs/lenet_swa.yaml | 9 ++-- .../mnist/configs/lenet_swag.yaml | 11 ++--- experiments/classification/mnist/lenet.py | 1 - torch_uncertainty/optim_recipes.py | 42 +++++++++++++++++++ 5 files changed, 59 insertions(+), 14 deletions(-) diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index c4635dfe..0c7989ab 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -51,7 +51,9 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index 8de26b64..fa3eb77d 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -57,7 +57,8 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch_uncertainty.optim_recipes.FullSWALR + init_args: + milestone: 20 + swa_lr: 0.01 + anneal_epochs: 5 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index bee2e7d3..292b49f0 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -43,7 +43,7 @@ model: dropout_rate: 0 last_layer_dropout: false layer_args: {} - cycle_start: 2 + cycle_start: 10 cycle_length: 5 num_classes: 10 loss: CrossEntropyLoss @@ -57,7 +57,8 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch_uncertainty.optim_recipes.FullSWALR + init_args: + milestone: 10 + swa_lr: 0.01 + anneal_epochs: 5 diff --git a/experiments/classification/mnist/lenet.py b/experiments/classification/mnist/lenet.py index ffb2b8f2..c93610f4 100644 --- a/experiments/classification/mnist/lenet.py +++ b/experiments/classification/mnist/lenet.py @@ -9,7 +9,6 @@ class MNISTCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_optimizer_args(torch.optim.SGD) - parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) def cli_main() -> MNISTCLI: diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index a413b02c..e9edd107 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -1,6 +1,8 @@ from collections.abc import Callable from functools import partial +from typing import Literal +import torch from timm.optim import Lamb from torch import nn, optim from torch.optim import Optimizer @@ -422,3 +424,43 @@ def get_procedure( procedure = partial(batch_ensemble_wrapper, optim_recipe=procedure) return procedure + + +class FullSWALR(torch.optim.lr_scheduler.SequentialLR): + def __init__( + self, + optimizer: Optimizer, + milestone: int, + swa_lr: float, + anneal_epochs: int, + optim_eta_min: float = 0, + anneal_strategy: Literal["cos", "linear"] = "cos", + ) -> None: + """Chains a Cosine scheduler and a SWA scheduler. + + This class is an example of a wrapper to enable training SWA and SWAG + models using the CLI. You may create your own class following this + example. + + Args: + optimizer (Optimizer): The optimizer to be used. + milestone (int): The epoch to start the SWA. + swa_lr (float): The learning rate to use for the SWA model. + anneal_epochs (int): The number of epochs to anneal the learning rate. + optim_eta_min (float): The minimum learning rate for the first optimizer. + anneal_strategy (Literal["cos", "linear"]): The strategy to anneal the learning rate. + """ + optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=milestone, eta_min=optim_eta_min + ) + swa_scheduler = torch.optim.swa_utils.SWALR( + optimizer, + swa_lr=swa_lr, + anneal_epochs=anneal_epochs, + anneal_strategy=anneal_strategy, + ) + super().__init__( + optimizer=optimizer, + schedulers=[optim_scheduler, swa_scheduler], + milestones=[milestone], + ) From a3443e3b3cc839251d12b5733458d89ad8d4ab26 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 15:29:21 +0200 Subject: [PATCH 30/57] :bug: Improve SWA & SWAG --- torch_uncertainty/models/wrappers/swa.py | 33 ++++-- torch_uncertainty/models/wrappers/swag.py | 106 ++++++++++++++----- torch_uncertainty/routines/classification.py | 7 +- 3 files changed, 114 insertions(+), 32 deletions(-) diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py index 95589494..05aba48d 100644 --- a/torch_uncertainty/models/wrappers/swa.py +++ b/torch_uncertainty/models/wrappers/swa.py @@ -1,23 +1,42 @@ import copy import torch -from torch import nn +from torch import Tensor, nn from torch.utils.data import DataLoader class SWA(nn.Module): + num_avgd_models: Tensor + def __init__( self, model: nn.Module, cycle_start: int, cycle_length: int, ) -> None: + """Stochastic Weight Averaging. + + Update the SWA model every :attr:`cycle_length` epochs starting at + :attr:`cycle_start`. Uses the SWA model only at test time. Otherwise, + uses the base model for training. + + Args: + model (nn.Module): PyTorch model to be trained. + cycle_start (int): Epoch to start SWA. + cycle_length (int): Number of epochs between SWA updates. + + Reference: + Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson, A. G. + (2018). Averaging Weights Leads to Wider Optima and Better Generalization. + In NeurIPS 2018. + """ super().__init__() _swa_checks(cycle_start, cycle_length) self.model = model self.cycle_start = cycle_start self.cycle_length = cycle_length - self.num_averaged = 0 + + self.register_buffer("num_avgd_models", torch.tensor(0, device="cpu")) self.swa_model = None self.need_bn_update = False @@ -29,7 +48,7 @@ def update_model(self, epoch: int) -> None: ): if self.swa_model is None: self.swa_model = copy.deepcopy(self.model) - self.num_averaged = 1 + self.num_avgd_models = torch.tensor(1) else: for swa_param, param in zip( self.swa_model.parameters(), @@ -37,9 +56,9 @@ def update_model(self, epoch: int) -> None: strict=False, ): swa_param.data += (param.data - swa_param.data) / ( - self.num_averaged + 1 + self.num_avgd_models + 1 ) - self.num_averaged += 1 + self.num_avgd_models += 1 self.need_bn_update = True def eval_forward(self, x: torch.Tensor) -> torch.Tensor: @@ -53,9 +72,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.eval_forward(x) def update_bn(self, loader: DataLoader, device) -> None: - if self.need_bn_update: + if self.need_bn_update and self.swa_model is not None: torch.optim.swa_utils.update_bn( - loader, self.swag_model, device=device + loader, self.swa_model, device=device ) self.need_bn_update = False diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index 90a514af..4a44978e 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -1,4 +1,5 @@ import copy +from collections.abc import Mapping import torch from torch import Tensor, nn @@ -9,6 +10,7 @@ class SWAG(SWA): swag_stats: dict[str, Tensor] + prfx = "model.swag_stats." def __init__( self, @@ -18,7 +20,7 @@ def __init__( scale: float = 1.0, diag_covariance: bool = False, max_num_models: int = 20, - var_clamp: float = 1e-30, + var_clamp: float = 1e-6, num_estimators: int = 16, ) -> None: """Stochastic Weight Averaging Gaussian (SWAG). @@ -30,8 +32,9 @@ def __init__( Args: model (nn.Module): PyTorch model to be trained. - cycle_start (int): Epoch to start SWAG. - cycle_length (int): Number of epochs between SWAG updates. + cycle_start (int): Begininning of the first SWAG averaging cycle. + cycle_length (int): Number of epochs between SWAG updates. The + first update occurs at :attr:`cycle_start`+:attr:`cycle_length`. scale (float, optional): Scale of the Gaussian. Defaults to 1.0. diag_covariance (bool, optional): Whether to use a diagonal covariance. Defaults to False. @@ -66,7 +69,7 @@ def __init__( def eval_forward(self, x: torch.Tensor) -> torch.Tensor: if not self.fit: return self.model.forward(x) - return torch.cat([mod(x) for mod in self.samples]) + return torch.cat([mod.to(device=x.device)(x) for mod in self.samples]) def initialize_stats(self) -> None: """Initialize the SWAG dictionary of statistics.""" @@ -76,12 +79,14 @@ def initialize_stats(self) -> None: torch.zeros_like(param, device="cpu"), torch.zeros_like(param, device="cpu"), ) - self.swag_stats[name_p + "_mean"] = mean - self.swag_stats[name_p + "_sq_mean"] = squared_mean + self.swag_stats[self.prfx + name_p + "_mean"] = mean + self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean if not self.diag_covariance: covariance_sqrt = torch.zeros((0, param.numel()), device="cpu") - self.swag_stats[name_p + "_covariance_sqrt"] = covariance_sqrt + self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = ( + covariance_sqrt + ) @torch.no_grad() def update_model(self, epoch: int) -> None: @@ -95,14 +100,14 @@ def update_model(self, epoch: int) -> None: epoch (int): Current epoch. """ if not ( - epoch >= self.cycle_start + epoch > self.cycle_start and (epoch - self.cycle_start) % self.cycle_length == 0 ): return for name_p, param in self.model.named_parameters(): - mean = self.swag_stats[name_p + "_mean"] - squared_mean = self.swag_stats[name_p + "_sq_mean"] + mean = self.swag_stats[self.prfx + name_p + "_mean"] + squared_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] new_param = param.data.detach().cpu() mean = mean * self.num_models / ( @@ -112,16 +117,20 @@ def update_model(self, epoch: int) -> None: self.num_models + 1 ) + new_param**2 / (self.num_models + 1) - self.swag_stats[name_p + "_mean"] = mean - self.swag_stats[name_p + "_sq_mean"] = squared_mean + self.swag_stats[self.prfx + name_p + "_mean"] = mean + self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean if not self.diag_covariance: - covariance_sqrt = self.swag_stats[name_p + "_covariance_sqrt"] + covariance_sqrt = self.swag_stats[ + self.prfx + name_p + "_covariance_sqrt" + ] dev = (new_param - mean).view(-1, 1).t() covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) if self.num_models + 1 > self.max_num_models: covariance_sqrt = covariance_sqrt[1:, :] - self.swag_stats[name_p + "_covariance_sqrt"] = covariance_sqrt + self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = ( + covariance_sqrt + ) self.num_models += 1 @@ -184,30 +193,79 @@ def _fullrank_sample( new_sample = copy.deepcopy(self.model) for name_p, param in new_sample.named_parameters(): - mean = self.swag_stats[name_p + "_mean"] - sq_mean = self.swag_stats[name_p + "_sq_mean"] + mean = self.swag_stats[self.prfx + name_p + "_mean"] + sq_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] if not diagonal_covariance: - cov_mat_sqrt = self.swag_stats[name_p + "_covariance_sqrt"] + cov_mat_sqrt = self.swag_stats[ + self.prfx + name_p + "_covariance_sqrt" + ] - # draw diagonal variance sample var = torch.clamp(sq_mean - mean**2, self.var_clamp) var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) - # if covariance draw low rank sample if not diagonal_covariance: cov_sample = cov_mat_sqrt.t() @ torch.randn( (cov_mat_sqrt.size(0),) ) cov_sample /= (self.max_num_models - 1) ** 0.5 - rand_sample = var_sample + cov_sample.view_as(var_sample) - else: - rand_sample = var_sample + var_sample += cov_sample.view_as(var_sample) - sample = mean + scale**0.5 * rand_sample - param.data = sample.to(device=param.device, dtype=param.dtype) + sample = mean + scale**0.5 * var_sample + param.data = sample.to(device="cpu", dtype=param.dtype) return new_sample + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination |= self.swag_stats + + def state_dict( + self, *args, destination=None, prefix="", keep_vars=False + ) -> dict[str, Tensor]: + return self.swag_stats | super().state_dict( + *args, destination=destination, prefix=prefix, keep_vars=keep_vars + ) + + def _load(self, state_dict): + self.swag_stats = { + k: v for k, v in state_dict.items() if k in self.swag_stats + } + for k in self.swag_stats: + del state_dict[k] + self.samples = [ + self.sample(self.scale, self.diag_covariance) + for _ in range(self.num_estimators) + ] + self.need_bn_update = True + self.fit = True + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self._load(state_dict) + return super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def load_state_dict( + self, state_dict: Mapping, strict: bool = True, assign: bool = False + ): + self._load(state_dict) + return super().load_state_dict(state_dict, strict, assign) + def compute_logdet(self, block=False): raise NotImplementedError("Raise an issue if you need this feature") diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 9f297897..a0c4844a 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -342,6 +342,11 @@ def on_test_start(self) -> None: self.id_logit_storage = [] self.ood_logit_storage = [] + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device + ) + def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: """Forward pass of the model. @@ -467,7 +472,7 @@ def test_step( add_dataloader_idx=False, ) - if self.is_ensemble > 1: + if self.is_ensemble: self.test_id_ens_metrics.update(probs_per_est) if self.eval_ood: From 4a4eeac65927bbd36bb499e1d3658054287f32ff Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 16:09:48 +0200 Subject: [PATCH 31/57] :hammer: Refactor stochastic models --- auto_tutorials_source/tutorial_bayesian.py | 27 ++++++--- .../tutorial_mc_batch_norm.py | 5 +- .../layers/bayesian/bayes_conv.py | 26 ++++----- .../layers/bayesian/bayes_linear.py | 10 ++-- torch_uncertainty/layers/bayesian/sampler.py | 31 +++++----- torch_uncertainty/layers/distributions.py | 13 +++-- torch_uncertainty/models/__init__.py | 1 + torch_uncertainty/models/lenet.py | 20 +++---- torch_uncertainty/models/utils.py | 57 ------------------- torch_uncertainty/models/wrappers/__init__.py | 1 + .../models/wrappers/stochastic.py | 55 ++++++++++++++++++ torch_uncertainty/models/wrappers/swa.py | 4 +- 12 files changed, 131 insertions(+), 119 deletions(-) create mode 100644 torch_uncertainty/models/wrappers/stochastic.py diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 04f1202e..7ddd5c37 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -55,12 +55,12 @@ # We will use the Adam optimizer with the default learning rate of 0.001. -def optim_lenet(model: nn.Module) -> dict: +def optim_lenet(model: nn.Module): optimizer = optim.Adam( model.parameters(), lr=1e-3, ) - return {"optimizer": optimizer} + return optimizer # %% @@ -75,7 +75,7 @@ def optim_lenet(model: nn.Module) -> dict: trainer = Trainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1) # datamodule -root = Path("") / "data" +root = Path("data") datamodule = MNISTDataModule(root=root, batch_size=128, eval_ood=False) # model @@ -105,6 +105,7 @@ def optim_lenet(model: nn.Module) -> dict: num_classes=datamodule.num_classes, loss=loss, optim_recipe=optim_lenet(model), + is_ensemble=True ) # %% @@ -125,8 +126,10 @@ def optim_lenet(model: nn.Module) -> dict: # 6. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # -# Now that the model is trained, let's test it on MNIST - +# Now that the model is trained, let's test it on MNIST. +# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble +# and to the batch. As for TorchUncertainty 2.0, the ensemble dimension is merged with the batch dimension +# in this order (num_estimator x batch, classes). import matplotlib.pyplot as plt import numpy as np import torch @@ -148,14 +151,22 @@ def imshow(img): imshow(torchvision.utils.make_grid(images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) -logits = model(images) +# Put the model in eval mode to use several samples +model = model.eval() +logits = model(images).reshape(16, 128, 10) # num_estimators, batch_size, num_classes + +# We apply the softmax on the classes and average over the estimators probs = torch.nn.functional.softmax(logits, dim=-1) +avg_probs = probs.mean(dim=0) +var_probs = probs.std(dim=0) -_, predicted = torch.max(probs, 1) +_, predicted = torch.max(avg_probs, 1) print("Predicted digits: ", " ".join(f"{predicted[j]}" for j in range(4))) - +print("Std. dev. of the scores over the posterior samples", " ".join(f"{var_probs[j][predicted[j]]:.3}" for j in range(4))) # %% +# The scores should be quite certain. +# # References # ---------- # diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 12781e2b..3827472d 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -102,6 +102,9 @@ # .eval() to enable Monte Carlo batch normalization at inference. # In this tutorial, we plot the most uncertain images, i.e. the images for which # the variance of the predictions is the highest. +# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble +# and to the batch. As for TorchUncertainty 2.0, the ensemble dimension is merged with the batch dimension +# in this order (num_estimator x batch, classes). import matplotlib.pyplot as plt import numpy as np @@ -121,7 +124,7 @@ def imshow(img): images, labels = next(dataiter) routine.eval() -logits = routine(images).reshape(8, 128, 10) +logits = routine(images).reshape(8, 128, 10) # num_estimators, batch_size, num_classes probs = torch.nn.functional.softmax(logits, dim=-1) most_uncertain = sorted(probs.var(0).sum(-1).topk(4).indices) diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index d9fc4df4..3584ba77 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -11,7 +11,7 @@ ) from torch.nn.parameter import Parameter -from .sampler import PriorDistribution, TrainableDistribution +from .sampler import CenteredGaussianMixture, TrainableDistribution __all__ = ["BayesConv1d", "BayesConv2d", "BayesConv3d"] @@ -137,11 +137,11 @@ def __init__( self.register_parameter("bias_mu", None) self.register_parameter("bias_sigma", None) - self.weight_prior_dist = PriorDistribution( + self.weight_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) if bias: - self.bias_prior_dist = PriorDistribution( + self.bias_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) @@ -290,14 +290,14 @@ def forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = ( self.weight_sampler.log_posterior() + bias_lposterior ) - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) @@ -323,9 +323,7 @@ def __init__( device=None, dtype=None, ) -> None: - """Bayesian Conv2d Layer with Mixture of Normals prior and Normal - posterior. - """ + """Bayesian Conv2d Layer with Gaussian Mixture prior and Normal posterior.""" factory_kwargs = {"device": device, "dtype": dtype} kernel_size_ = _pair(kernel_size) stride_ = _pair(stride) @@ -389,14 +387,14 @@ def forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = ( self.weight_sampler.log_posterior() + bias_lposterior ) - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) @@ -422,9 +420,7 @@ def __init__( device=None, dtype=None, ) -> None: - """Bayesian Conv3d Layer with Mixture of Normals prior and Normal - posterior. - """ + """Bayesian Conv3d Layer with Gaussian mixture prior and Normal posterior.""" factory_kwargs = {"device": device, "dtype": dtype} kernel_size_ = _triple(kernel_size) stride_ = _triple(stride) @@ -488,13 +484,13 @@ def forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = ( self.weight_sampler.log_posterior() + bias_lposterior ) - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) diff --git a/torch_uncertainty/layers/bayesian/bayes_linear.py b/torch_uncertainty/layers/bayesian/bayes_linear.py index 074f8554..2c9f15c4 100644 --- a/torch_uncertainty/layers/bayesian/bayes_linear.py +++ b/torch_uncertainty/layers/bayesian/bayes_linear.py @@ -3,7 +3,7 @@ from torch import Tensor, nn from torch.nn import init -from .sampler import PriorDistribution, TrainableDistribution +from .sampler import CenteredGaussianMixture, TrainableDistribution class BayesLinear(nn.Module): @@ -91,11 +91,11 @@ def __init__( self.bias_mu, self.bias_sigma ) - self.weight_prior_dist = PriorDistribution( + self.weight_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) if bias: - self.bias_prior_dist = PriorDistribution( + self.bias_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) @@ -122,12 +122,12 @@ def _forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = self.weight_sampler.log_posterior() + bias_lposterior - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return F.linear(inputs, weight, bias) diff --git a/torch_uncertainty/layers/bayesian/sampler.py b/torch_uncertainty/layers/bayesian/sampler.py index a512fad7..dd5e710a 100644 --- a/torch_uncertainty/layers/bayesian/sampler.py +++ b/torch_uncertainty/layers/bayesian/sampler.py @@ -18,9 +18,7 @@ def __init__( self.weight = None def sample(self) -> Tensor: - w_sample = torch.normal( - mean=0, std=1, size=self.mu.shape, device=self.mu.device - ) + w_sample = torch.randn(size=self.mu.shape, device=self.mu.device) self.sigma = torch.log1p(torch.exp(self.rho)).to(self.mu.device) self.weight = self.mu + self.sigma * w_sample return self.weight @@ -43,26 +41,27 @@ def log_posterior(self, weight: Tensor | None = None) -> Tensor: return -lposterior.sum() -class PriorDistribution(nn.Module): +class CenteredGaussianMixture(nn.Module): def __init__( self, sigma_1: float, sigma_2: float, pi: float, ) -> None: + """Create a mixture of two centered Gaussian distributions. + + Args: + sigma_1 (float): Standard deviation of the first Gaussian. + sigma_2 (float): Standard deviation of the second Gaussian. + pi (float): Mixing coefficient. + """ super().__init__() - self.pi = torch.tensor([pi, 1 - pi]) - self.mus = torch.zeros(2) - self.sigmas = torch.tensor([sigma_1, sigma_2]) + self.register_buffer("pi", torch.tensor([pi, 1 - pi])) + self.register_buffer("mus", torch.zeros(2)) + self.register_buffer("sigmas", torch.tensor([sigma_1, sigma_2])) - def log_prior(self, weight: Tensor) -> Tensor: - self.convert(weight.device) + def log_prob(self, weight: Tensor) -> Tensor: mix = distributions.Categorical(self.pi) normals = distributions.Normal(self.mus, self.sigmas) - self.distribution = distributions.MixtureSameFamily(mix, normals) - return self.distribution.log_prob(weight).sum() - - def convert(self, device) -> None: - self.pi = self.pi.to(device) - self.mus = self.mus.to(device) - self.sigmas = self.sigmas.to(device) + distribution = distributions.MixtureSameFamily(mix, normals) + return distribution.log_prob(weight).sum() diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index 341cf5b9..956fb537 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -1,3 +1,5 @@ +from abc import ABC, abstractmethod + import torch.nn.functional as F from torch import Tensor, nn from torch.distributions import Distribution, Laplace, Normal @@ -5,18 +7,19 @@ from torch_uncertainty.utils.distributions import NormalInverseGamma -class _AbstractDist(nn.Module): +class AbstractDist(ABC, nn.Module): def __init__(self, dim: int) -> None: super().__init__() if dim < 1: raise ValueError(f"dim must be positive, got {dim}.") self.dim = dim + @abstractmethod def forward(self, x: Tensor) -> Distribution: - raise NotImplementedError + pass -class NormalLayer(_AbstractDist): +class NormalLayer(AbstractDist): """Normal distribution layer. Converts model outputs to Independent Normal distributions. @@ -46,7 +49,7 @@ def forward(self, x: Tensor) -> Normal: return Normal(loc, scale) -class LaplaceLayer(_AbstractDist): +class LaplaceLayer(AbstractDist): """Laplace distribution layer. Converts model outputs to Independent Laplace distributions. @@ -76,7 +79,7 @@ def forward(self, x: Tensor) -> Laplace: return Laplace(loc, scale) -class NormalInverseGammaLayer(_AbstractDist): +class NormalInverseGammaLayer(AbstractDist): """Normal-Inverse-Gamma distribution layer. Converts model outputs to Independent Normal-Inverse-Gamma distributions. diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 81dc9663..4b8964bc 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -7,6 +7,7 @@ SWAG, CheckpointEnsemble, MCDropout, + StochasticModel, deep_ensembles, mc_dropout, ) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 34a4da00..6804c6f9 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -8,7 +8,7 @@ from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear -from torch_uncertainty.models.utils import stochastic_model +from torch_uncertainty.models import StochasticModel __all__ = ["bayesian_lenet", "lenet", "packed_lenet"] @@ -83,16 +83,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc3(out) -@stochastic_model -class _StochasticLeNet(_LeNet): - pass - - def _lenet( stochastic: bool, in_channels: int, num_classes: int, layer_args: dict, + num_samples: int = 16, linear_layer: type[nn.Module] = nn.Linear, conv2d_layer: type[nn.Module] = nn.Conv2d, activation: Callable = nn.ReLU, @@ -100,9 +96,8 @@ def _lenet( groups: int = 1, dropout_rate: float = 0.0, last_layer_dropout: bool = False, -) -> _LeNet | _StochasticLeNet: - model = _LeNet if not stochastic else _StochasticLeNet - return model( +) -> _LeNet | StochasticModel: + model = _LeNet( in_channels=in_channels, num_classes=num_classes, linear_layer=linear_layer, @@ -114,6 +109,9 @@ def _lenet( dropout_rate=dropout_rate, last_layer_dropout=last_layer_dropout, ) + if stochastic: + return StochasticModel(model, num_samples) + return model def lenet( @@ -172,6 +170,7 @@ def packed_lenet( def bayesian_lenet( in_channels: int, num_classes: int, + num_samples: int = 16, prior_sigma_1: float | None = None, prior_sigma_2: float | None = None, prior_pi: float | None = None, @@ -181,7 +180,7 @@ def bayesian_lenet( norm: type[nn.Module] = nn.Identity, groups: int = 1, dropout_rate: float = 0.0, -) -> _StochasticLeNet: +) -> StochasticModel: layers_args = {} if prior_sigma_1 is not None: layers_args["prior_sigma_1"] = prior_sigma_1 @@ -196,6 +195,7 @@ def bayesian_lenet( return _lenet( stochastic=True, + num_samples=num_samples, in_channels=in_channels, num_classes=num_classes, linear_layer=BayesLinear, diff --git a/torch_uncertainty/models/utils.py b/torch_uncertainty/models/utils.py index cb0bcc02..87fd65ee 100644 --- a/torch_uncertainty/models/utils.py +++ b/torch_uncertainty/models/utils.py @@ -1,62 +1,5 @@ from torch import Tensor, nn -from torch_uncertainty.layers.bayesian import bayesian_modules - - -def stochastic_model(model: nn.Module) -> nn.Module: - """Decorator for stochastic models. - - When applied to a model, it adds the `sample`, `freeze` and `unfreeze` - methods. Use `freeze` to obtain deterministic outputs. Use unfreeze to - obtain stochastic outputs. `sample` get samples of the estimated posterior - distribution. - - Args: - model (nn.Module): PyTorch model. - """ - - def sample(self, num_samples: int = 1) -> list[dict]: - sampled_models = [{}] * num_samples - for module_name in self._modules: - module = self._modules[module_name] - if isinstance(module, bayesian_modules): - for model in sampled_models: - weight, bias = module.sample() - model[module_name + ".weight"] = weight - if bias is not None: - model[module_name + ".bias"] = bias - else: - for model in sampled_models: - state = module.state_dict() - if not len(state): # no parameter - break - # TODO: fix this - model |= { - module_name + "." + key: val - for key, val in module.state_dict().items() - } - return sampled_models - - model.sample = sample - - def freeze(self) -> None: - for module_name in self._modules: - module = self._modules[module_name] - if isinstance(module, bayesian_modules): - module.freeze() - - model.freeze = freeze - - def unfreeze(self) -> None: - for module_name in self._modules: - module = self._modules[module_name] - if isinstance(module, bayesian_modules): - module.unfreeze() - - model.unfreeze = unfreeze - - return model - class Backbone(nn.Module): def __init__(self, model: nn.Module, feat_names: list[str]) -> None: diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index d1895438..75f37e66 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -5,6 +5,7 @@ from .deep_ensembles import deep_ensembles from .ema import EMA from .mc_dropout import MCDropout, mc_dropout +from .stochastic import StochasticModel from .swa import SWA from .swag import SWAG diff --git a/torch_uncertainty/models/wrappers/stochastic.py b/torch_uncertainty/models/wrappers/stochastic.py new file mode 100644 index 00000000..6a424240 --- /dev/null +++ b/torch_uncertainty/models/wrappers/stochastic.py @@ -0,0 +1,55 @@ +import torch +from torch import Tensor, nn + +from torch_uncertainty.layers.bayesian import bayesian_modules + + +class StochasticModel(nn.Module): + def __init__(self, model: nn.Module, num_samples: int) -> None: + super().__init__() + self.model = model + self.num_samples = num_samples + + def eval_forward(self, x: Tensor) -> Tensor: + return torch.cat( + [self.model.forward(x) for _ in range(self.num_samples)], dim=0 + ) + + def forward(self, x: Tensor) -> Tensor: + if self.training: + return self.model.forward(x) + return self.eval_forward(x) + + def sample(self, num_samples: int = 1) -> list[dict]: + sampled_models = [{}] * num_samples + for module_name in self._modules: + module = self._modules[module_name] + if isinstance(module, bayesian_modules): + for model in sampled_models: + weight, bias = module.sample() + model[module_name + ".weight"] = weight + if bias is not None: + model[module_name + ".bias"] = bias + else: + for model in sampled_models: + state = module.state_dict() + if not len(state): # no parameter + break + # TODO: fix this + model |= { + module_name + "." + key: val + for key, val in module.state_dict().items() + } + return sampled_models + + def freeze(self) -> None: + for module_name in self._modules: + module = self._modules[module_name] + if isinstance(module, bayesian_modules): + module.freeze() + + def unfreeze(self) -> None: + for module_name in self._modules: + module = self._modules[module_name] + if isinstance(module, bayesian_modules): + module.unfreeze() diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py index 05aba48d..06a59689 100644 --- a/torch_uncertainty/models/wrappers/swa.py +++ b/torch_uncertainty/models/wrappers/swa.py @@ -61,12 +61,12 @@ def update_model(self, epoch: int) -> None: self.num_avgd_models += 1 self.need_bn_update = True - def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + def eval_forward(self, x: Tensor) -> Tensor: if self.swa_model is None: return self.model.forward(x) return self.swa_model.forward(x) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: if self.training: return self.model.forward(x) return self.eval_forward(x) From 737d862da7fef474ce9e95b48080691b2135acf8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 16:15:53 +0200 Subject: [PATCH 32/57] :bug: Fix Stochastic MLP error --- tests/models/test_stochastic_model.py | 113 +++++++++++++------------- torch_uncertainty/models/mlp.py | 20 ++--- 2 files changed, 67 insertions(+), 66 deletions(-) diff --git a/tests/models/test_stochastic_model.py b/tests/models/test_stochastic_model.py index b7ab5a75..3be9f475 100644 --- a/tests/models/test_stochastic_model.py +++ b/tests/models/test_stochastic_model.py @@ -1,75 +1,76 @@ -from torch import nn +# ruff: noqa: ERA001 +# from torch import nn -from torch_uncertainty.layers import BayesConv2d, BayesLinear -from torch_uncertainty.models.utils import stochastic_model +# from torch_uncertainty.layers import BayesConv2d, BayesLinear +# from torch_uncertainty.models.utils import stochastic_model -@stochastic_model -class DummyModelLinear(nn.Module): - """Dummy model for testing purposes.""" +# @stochastic_model +# class DummyModelLinear(nn.Module): +# """Dummy model for testing purposes.""" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.layer = BayesLinear(1, 10, 1) +# def __init__(self, *args, **kwargs) -> None: +# super().__init__(*args, **kwargs) +# self.layer = BayesLinear(1, 10, 1) - def forward(self, x): - return self.layer(x) +# def forward(self, x): +# return self.layer(x) -@stochastic_model -class DummyModelConv(nn.Module): - """Dummy conv model for testing purposes.""" +# @stochastic_model +# class DummyModelConv(nn.Module): +# """Dummy conv model for testing purposes.""" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.layer = BayesConv2d(1, 10, 1) +# def __init__(self, *args, **kwargs) -> None: +# super().__init__(*args, **kwargs) +# self.layer = BayesConv2d(1, 10, 1) - def forward(self, x): - return self.layer(x) +# def forward(self, x): +# return self.layer(x) -@stochastic_model -class DummyModelMix(nn.Module): - """Dummy mix model for testing purposes.""" +# @stochastic_model +# class DummyModelMix(nn.Module): +# """Dummy mix model for testing purposes.""" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.layer = BayesConv2d(1, 10, 1, bias=False) - self.relu = nn.ReLU() - self.layer2 = nn.Conv2d(10, 1, 1) +# def __init__(self, *args, **kwargs) -> None: +# super().__init__(*args, **kwargs) +# self.layer = BayesConv2d(1, 10, 1, bias=False) +# self.relu = nn.ReLU() +# self.layer2 = nn.Conv2d(10, 1, 1) - def forward(self, x): - y = self.relu(self.layer(x)) - return self.layer2(y) +# def forward(self, x): +# y = self.relu(self.layer(x)) +# return self.layer2(y) -class TestStochasticModel: - """Testing the StochasticModel decorator.""" +# class TestStochasticModel: +# """Testing the StochasticModel decorator.""" - def test_main(self): - model = DummyModelLinear() - model.freeze() - assert model.layer.frozen - model.unfreeze() - assert not model.layer.frozen +# def test_main(self): +# model = DummyModelLinear() +# model.freeze() +# assert model.layer.frozen +# model.unfreeze() +# assert not model.layer.frozen - model = DummyModelConv() - model.freeze() - assert model.layer.frozen - model.unfreeze() - assert not model.layer.frozen +# model = DummyModelConv() +# model.freeze() +# assert model.layer.frozen +# model.unfreeze() +# assert not model.layer.frozen - def test_mix(self): - model = DummyModelMix() - model.freeze() - assert model.layer.frozen - model.unfreeze() - assert not model.layer.frozen +# def test_mix(self): +# model = DummyModelMix() +# model.freeze() +# assert model.layer.frozen +# model.unfreeze() +# assert not model.layer.frozen - state = model.sample()[0] - keys = state.keys() - assert list(keys) == [ - "layer.weight", - "layer2.weight", - "layer2.bias", - ] +# state = model.sample()[0] +# keys = state.keys() +# assert list(keys) == [ +# "layer.weight", +# "layer2.weight", +# "layer2.bias", +# ] diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 1a50524f..d0fdee07 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -5,7 +5,7 @@ from torch_uncertainty.layers.bayesian import BayesLinear from torch_uncertainty.layers.packed import PackedLinear -from torch_uncertainty.models.utils import stochastic_model +from torch_uncertainty.models import StochasticModel __all__ = ["bayesian_mlp", "mlp", "packed_mlp"] @@ -84,29 +84,24 @@ def forward(self, x: Tensor) -> Tensor: return self.final_layer(self.layers[-1](x)) -@stochastic_model -class _StochasticMLP(_MLP): - pass - - def _mlp( stochastic: bool, in_features: int, num_outputs: int, hidden_dims: list[int], + num_samples: int = 16, layer_args: dict | None = None, layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, -) -> _MLP | _StochasticMLP: +) -> _MLP | StochasticModel: if layer_args is None: layer_args = {} if final_layer_args is None: final_layer_args = {} - model = _MLP if not stochastic else _StochasticMLP - return model( + model = _MLP( in_features=in_features, num_outputs=num_outputs, hidden_dims=hidden_dims, @@ -117,6 +112,9 @@ def _mlp( final_layer_args=final_layer_args, dropout_rate=dropout_rate, ) + if stochastic: + return StochasticModel(model, num_samples) + return model def mlp( @@ -194,13 +192,15 @@ def bayesian_mlp( in_features: int, num_outputs: int, hidden_dims: list[int], + num_samples: int = 16, activation: Callable = F.relu, final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, -) -> _StochasticMLP: +) -> StochasticModel: return _mlp( stochastic=True, + num_samples=num_samples, in_features=in_features, num_outputs=num_outputs, hidden_dims=hidden_dims, From 6f573324d3c53f2a47fb44a0a51b8f750e7faf5c Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 16:26:42 +0200 Subject: [PATCH 33/57] :books: Update documentation --- docs/source/api.rst | 4 +++- docs/source/references.rst | 17 ++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index ea87676f..8b5a4c7c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -164,8 +164,10 @@ Wrappers deep_ensembles CheckpointEnsemble EMA + StochasticModel SWA - MC_Dropout + SWAG + MCDropout mc_dropout Metrics diff --git a/docs/source/references.rst b/docs/source/references.rst index f2c4e71b..66cfdcd6 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -41,10 +41,10 @@ For Deep Evidential Regression, consider citing: * Paper: `NeurIPS 2020 `__. -Bayesian Neural Networks -^^^^^^^^^^^^^^^^^^^^^^^^ +Variational Inference Bayesian Neural Networks +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -For Bayesian Neural Networks, consider citing: +For Variational Inference Bayesian Neural Networks, consider citing: **Weight Uncertainty in Neural Networks** @@ -83,6 +83,17 @@ For Stochastic Weight Averaging, consider citing: * Authors: *Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson* * Paper: `UAI 2018 `__. +Stochastic Weight Averaging Gaussian +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For Stochastic Weight Averaging Gaussian, consider citing: + +**A simple baseline for Bayesian uncertainty in deep learning** + +* Authors: *Wesley Maddox, Timur Garipov, Pavel Izmailov, Dmitry Vetrov, Andrew Gordon Wilson* +* Paper: `NeurIPS 2019 `__. + + CheckpointEnsemble ^^^^^^^^^^^^^^^^^^ From 00eb701c361722b16a619367c0e51ad0c70f1d3a Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 16:27:07 +0200 Subject: [PATCH 34/57] :books: Fix bugs in docs --- torch_uncertainty/post_processing/__init__.py | 1 + torch_uncertainty/routines/segmentation.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index 793e3637..2dfd35c4 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler +from .laplace import Laplace from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 6e1e11e2..4d1b2fbc 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -37,7 +37,7 @@ def __init__( num_classes (int): Number of classes in the segmentation task. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to ̀`1` (single model). + ensemble. Defaults to ``1`` (single model). optim_recipe (dict or Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. format_batch_fn (torch.nn.Module, optional): The function to format the From 4e1e8af8e729da6f236eaa8c16892df46f4e533a Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 17:23:00 +0200 Subject: [PATCH 35/57] :white_check_mark: Add first battery of tests --- .github/workflows/run-tests.yml | 2 +- docs/source/api.rst | 2 +- tests/_dummies/model.py | 20 +--- tests/models/test_stochastic_model.py | 76 ---------------- tests/models/wrappers/__init__.py | 0 .../wrappers/test_checkpoint_ensemble.py | 26 ++++++ .../{ => wrappers}/test_deep_ensembles.py | 0 tests/models/wrappers/test_ema.py | 21 +++++ .../models/{ => wrappers}/test_mc_dropout.py | 0 .../models/wrappers/test_stochastic_model.py | 73 +++++++++++++++ tests/models/wrappers/test_swa.py | 91 +++++++++++++++++++ tests/post_processing/test_laplace.py | 21 +++++ .../models/wrappers/stochastic.py | 12 +-- torch_uncertainty/models/wrappers/swa.py | 2 +- torch_uncertainty/models/wrappers/swag.py | 25 +---- torch_uncertainty/post_processing/__init__.py | 2 +- torch_uncertainty/post_processing/laplace.py | 19 ++-- 17 files changed, 260 insertions(+), 132 deletions(-) delete mode 100644 tests/models/test_stochastic_model.py create mode 100644 tests/models/wrappers/__init__.py create mode 100644 tests/models/wrappers/test_checkpoint_ensemble.py rename tests/models/{ => wrappers}/test_deep_ensembles.py (100%) create mode 100644 tests/models/wrappers/test_ema.py rename tests/models/{ => wrappers}/test_mc_dropout.py (100%) create mode 100644 tests/models/wrappers/test_stochastic_model.py create mode 100644 tests/models/wrappers/test_swa.py create mode 100644 tests/post_processing/test_laplace.py diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index c404a106..a870c00e 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -65,7 +65,7 @@ jobs: if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu - python3 -m pip install .[image,dev,docs] + python3 -m pip install .[all] - name: Check style & format if: steps.changed-files-specific.outputs.only_changed != 'true' diff --git a/docs/source/api.rst b/docs/source/api.rst index 8b5a4c7c..de590aa1 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -245,7 +245,7 @@ Post-Processing Methods :nosignatures: :template: class_inherited.rst MCBatchNorm - Laplace + LaplaceApprox Scaling Methods ^^^^^^^^^^^^^^^ diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 2e29e2b5..2c57aec0 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -12,23 +12,16 @@ def __init__( in_channels: int, num_classes: int, dropout_rate: float, - with_linear: bool, last_layer: nn.Module, ) -> None: super().__init__() self.in_channels = in_channels self.dropout_rate = dropout_rate - if with_linear: - self.linear = nn.Linear( - 1, - num_classes, - ) - else: - self.out = nn.Linear( - 1, - num_classes, - ) + self.linear = nn.Linear( + 1, + num_classes, + ) self.last_layer = last_layer self.dropout = nn.Dropout(p=dropout_rate) @@ -92,7 +85,6 @@ def dummy_model( num_classes: int, dropout_rate: float = 0.0, with_feats: bool = True, - with_linear: bool = True, last_layer=None, ) -> _Dummy: """Dummy model for testing purposes. @@ -103,8 +95,6 @@ def dummy_model( num_estimators (int): Number of estimators in the ensemble. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. with_feats (bool, optional): Whether to include features. Defaults to True. - with_linear (bool, optional): Whether to include a linear layer. - Defaults to True. last_layer ([type], optional): Last layer of the model. Defaults to None. Returns: @@ -117,14 +107,12 @@ def dummy_model( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, - with_linear=with_linear, last_layer=last_layer, ) return _Dummy( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, - with_linear=with_linear, last_layer=last_layer, ) diff --git a/tests/models/test_stochastic_model.py b/tests/models/test_stochastic_model.py deleted file mode 100644 index 3be9f475..00000000 --- a/tests/models/test_stochastic_model.py +++ /dev/null @@ -1,76 +0,0 @@ -# ruff: noqa: ERA001 -# from torch import nn - -# from torch_uncertainty.layers import BayesConv2d, BayesLinear -# from torch_uncertainty.models.utils import stochastic_model - - -# @stochastic_model -# class DummyModelLinear(nn.Module): -# """Dummy model for testing purposes.""" - -# def __init__(self, *args, **kwargs) -> None: -# super().__init__(*args, **kwargs) -# self.layer = BayesLinear(1, 10, 1) - -# def forward(self, x): -# return self.layer(x) - - -# @stochastic_model -# class DummyModelConv(nn.Module): -# """Dummy conv model for testing purposes.""" - -# def __init__(self, *args, **kwargs) -> None: -# super().__init__(*args, **kwargs) -# self.layer = BayesConv2d(1, 10, 1) - -# def forward(self, x): -# return self.layer(x) - - -# @stochastic_model -# class DummyModelMix(nn.Module): -# """Dummy mix model for testing purposes.""" - -# def __init__(self, *args, **kwargs) -> None: -# super().__init__(*args, **kwargs) -# self.layer = BayesConv2d(1, 10, 1, bias=False) -# self.relu = nn.ReLU() -# self.layer2 = nn.Conv2d(10, 1, 1) - -# def forward(self, x): -# y = self.relu(self.layer(x)) -# return self.layer2(y) - - -# class TestStochasticModel: -# """Testing the StochasticModel decorator.""" - -# def test_main(self): -# model = DummyModelLinear() -# model.freeze() -# assert model.layer.frozen -# model.unfreeze() -# assert not model.layer.frozen - -# model = DummyModelConv() -# model.freeze() -# assert model.layer.frozen -# model.unfreeze() -# assert not model.layer.frozen - -# def test_mix(self): -# model = DummyModelMix() -# model.freeze() -# assert model.layer.frozen -# model.unfreeze() -# assert not model.layer.frozen - -# state = model.sample()[0] -# keys = state.keys() -# assert list(keys) == [ -# "layer.weight", -# "layer2.weight", -# "layer2.bias", -# ] diff --git a/tests/models/wrappers/__init__.py b/tests/models/wrappers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/wrappers/test_checkpoint_ensemble.py b/tests/models/wrappers/test_checkpoint_ensemble.py new file mode 100644 index 00000000..239d13af --- /dev/null +++ b/tests/models/wrappers/test_checkpoint_ensemble.py @@ -0,0 +1,26 @@ +import torch + +from tests._dummies.model import dummy_model +from torch_uncertainty.models import CheckpointEnsemble + + +class TestCheckpointEnsemble: + """Testing the CheckpointEnsemble class.""" + + def test_training(self): + ens = CheckpointEnsemble(dummy_model(1, 10)) + ens.eval() + ens(torch.randn(1, 1)) + + ens.train() + ens(torch.randn(1, 1)) + ens.update_model(0) + ens.eval() + ens(torch.randn(1, 1)) + + ens = CheckpointEnsemble(dummy_model(1, 10), use_final_checkpoint=False) + ens.train() + ens(torch.randn(1, 1)) + ens.update_model(0) + ens.eval() + ens(torch.randn(1, 1)) diff --git a/tests/models/test_deep_ensembles.py b/tests/models/wrappers/test_deep_ensembles.py similarity index 100% rename from tests/models/test_deep_ensembles.py rename to tests/models/wrappers/test_deep_ensembles.py diff --git a/tests/models/wrappers/test_ema.py b/tests/models/wrappers/test_ema.py new file mode 100644 index 00000000..52b81210 --- /dev/null +++ b/tests/models/wrappers/test_ema.py @@ -0,0 +1,21 @@ +import pytest +import torch +from torch import nn + +from tests._dummies.model import dummy_model +from torch_uncertainty.models import EMA + + +class TestEMA: + """Testing the EMA class.""" + + def test_training(self): + ema = EMA(dummy_model(1, 10), momentum=0.99) + ema.eval() + ema(torch.randn(1, 1)) + ema.train() + ema.update_model(0) + + def test_failures(self): + with pytest.raises(ValueError, match="must be in the range"): + EMA(nn.Module(), momentum=-1) diff --git a/tests/models/test_mc_dropout.py b/tests/models/wrappers/test_mc_dropout.py similarity index 100% rename from tests/models/test_mc_dropout.py rename to tests/models/wrappers/test_mc_dropout.py diff --git a/tests/models/wrappers/test_stochastic_model.py b/tests/models/wrappers/test_stochastic_model.py new file mode 100644 index 00000000..c7709650 --- /dev/null +++ b/tests/models/wrappers/test_stochastic_model.py @@ -0,0 +1,73 @@ +from torch import nn + +from torch_uncertainty.layers import BayesConv2d, BayesLinear +from torch_uncertainty.models import StochasticModel + + +class DummyModelLinear(nn.Module): + """Dummy model for testing purposes.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layer = BayesLinear(1, 10, 1) + + def forward(self, x): + return self.layer(x) + + +class DummyModelConv(nn.Module): + """Dummy conv model for testing purposes.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layer = BayesConv2d(1, 10, 1) + + def forward(self, x): + return self.layer(x) + + +class DummyModelMix(nn.Module): + """Dummy mix model for testing purposes.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layer = BayesConv2d(1, 10, 1, bias=False) + self.relu = nn.ReLU() + self.layer2 = nn.Conv2d(10, 1, 1) + + def forward(self, x): + y = self.relu(self.layer(x)) + return self.layer2(y) + + +class TestStochasticModel: + """Testing the StochasticModel decorator.""" + + def test_main(self): + model = StochasticModel(DummyModelLinear(), 2) + model.freeze() + assert model.model.layer.frozen + model.unfreeze() + assert not model.model.layer.frozen + + model = StochasticModel(DummyModelConv(), 2) + model.freeze() + assert model.model.layer.frozen + model.unfreeze() + assert not model.model.layer.frozen + + def test_mix(self): + model = StochasticModel(DummyModelMix(), 2) + model.freeze() + assert model.model.layer.frozen + model.unfreeze() + assert not model.model.layer.frozen + + state = model.sample()[0] + keys = state.keys() + print(list(keys)) + assert list(keys) == [ + "layer.weight", + "layer2.weight", + "layer2.bias", + ] diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py new file mode 100644 index 00000000..75fcdafe --- /dev/null +++ b/tests/models/wrappers/test_swa.py @@ -0,0 +1,91 @@ +import pytest +import torch +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +from tests._dummies.model import dummy_model +from torch_uncertainty.models import SWA, SWAG + + +class TestSWA: + """Testing the SWA class.""" + + def test_training(self): + dl = DataLoader(TensorDataset(torch.randn(1, 1)), batch_size=1) + swa = SWA(dummy_model(1, 10), cycle_start=1, cycle_length=1) + swa.eval() + swa(torch.randn(1, 1)) + + swa.train() + swa(torch.randn(1, 1)) + swa.update_model(0) + swa.update_bn(dl, "cpu") + + swa.update_model(1) + swa.update_bn(dl, "cpu") + + swa.eval() + swa(torch.randn(1, 1)) + + def test_failures(self): + with pytest.raises( + ValueError, match="`cycle_start` must be non-negative." + ): + SWA(nn.Module(), cycle_start=-1, cycle_length=1) + with pytest.raises( + ValueError, match="`cycle_length` must be strictly positive." + ): + SWA(nn.Module(), cycle_start=1, cycle_length=0) + + +class TestSWAG: + """Testing the SWAG class.""" + + def test_training(self): + dl = DataLoader(TensorDataset(torch.randn(1, 1)), batch_size=1) + swag = SWAG(dummy_model(1, 10), cycle_start=1, cycle_length=1) + swag.eval() + swag(torch.randn(1, 1)) + + swag.train() + swag(torch.randn(1, 1)) + swag.update_model(0) + swag.update_bn(dl, "cpu") + + swag.update_model(1) + swag.update_bn(dl, "cpu") + + swag.eval() + swag(torch.randn(1, 1)) + + swag = SWAG( + dummy_model(1, 10), + cycle_start=1, + cycle_length=1, + diag_covariance=False, + ) + swag.train() + swag.update_model(1) + + def test_state_dict(self): + mod = dummy_model(1, 10) + swag = SWAG(mod, cycle_start=1, cycle_length=1, num_estimators=3) + print(swag.state_dict()) + swag.load_state_dict(swag.state_dict()) + + def test_failures(self): + with pytest.raises( + NotImplementedError, match="Raise an issue if you need this feature" + ): + swag = SWAG(nn.Module(), scale=1, cycle_start=1, cycle_length=1) + swag.sample(scale=1, block=True) + with pytest.raises(ValueError, match="`scale` must be non-negative."): + SWAG(nn.Module(), scale=-1, cycle_start=1, cycle_length=1) + with pytest.raises( + ValueError, match="`max_num_models` must be non-negative." + ): + SWAG(nn.Module(), max_num_models=-1, cycle_start=1, cycle_length=1) + with pytest.raises( + ValueError, match="`var_clamp` must be non-negative. " + ): + SWAG(nn.Module(), var_clamp=-1, cycle_start=1, cycle_length=1) diff --git a/tests/post_processing/test_laplace.py b/tests/post_processing/test_laplace.py new file mode 100644 index 00000000..c9aa8e72 --- /dev/null +++ b/tests/post_processing/test_laplace.py @@ -0,0 +1,21 @@ +import torch +from torch import nn +from torch.utils.data import TensorDataset + +from tests._dummies.model import dummy_model +from torch_uncertainty.post_processing import LaplaceApprox + + +class TestLaplace: + """Testing the LaplaceApprox class.""" + + def test_training(self): + ds = TensorDataset(torch.randn(16, 1), torch.randn(16, 10)) + la = LaplaceApprox( + task="classification", + model=dummy_model(1, 10, last_layer=nn.Linear(10, 10)), + ) + la.fit(ds) + la(torch.randn(1, 1)) + la = LaplaceApprox(task="classification") + la.set_model(dummy_model(1, 10)) diff --git a/torch_uncertainty/models/wrappers/stochastic.py b/torch_uncertainty/models/wrappers/stochastic.py index 6a424240..147b97cb 100644 --- a/torch_uncertainty/models/wrappers/stochastic.py +++ b/torch_uncertainty/models/wrappers/stochastic.py @@ -22,8 +22,8 @@ def forward(self, x: Tensor) -> Tensor: def sample(self, num_samples: int = 1) -> list[dict]: sampled_models = [{}] * num_samples - for module_name in self._modules: - module = self._modules[module_name] + for module_name in self.model._modules: + module = self.model._modules[module_name] if isinstance(module, bayesian_modules): for model in sampled_models: weight, bias = module.sample() @@ -43,13 +43,13 @@ def sample(self, num_samples: int = 1) -> list[dict]: return sampled_models def freeze(self) -> None: - for module_name in self._modules: - module = self._modules[module_name] + for module_name in self.model._modules: + module = self.model._modules[module_name] if isinstance(module, bayesian_modules): module.freeze() def unfreeze(self) -> None: - for module_name in self._modules: - module = self._modules[module_name] + for module_name in self.model._modules: + module = self.model._modules[module_name] if isinstance(module, bayesian_modules): module.unfreeze() diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py index 06a59689..5125b754 100644 --- a/torch_uncertainty/models/wrappers/swa.py +++ b/torch_uncertainty/models/wrappers/swa.py @@ -28,7 +28,7 @@ def __init__( Reference: Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson, A. G. (2018). Averaging Weights Leads to Wider Optima and Better Generalization. - In NeurIPS 2018. + In UAI 2018. """ super().__init__() _swa_checks(cycle_start, cycle_length) diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index 4a44978e..3c053135 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -226,7 +226,7 @@ def state_dict( *args, destination=destination, prefix=prefix, keep_vars=keep_vars ) - def _load(self, state_dict): + def _load_swag_stats(self, state_dict): self.swag_stats = { k: v for k, v in state_dict.items() if k in self.swag_stats } @@ -239,31 +239,10 @@ def _load(self, state_dict): self.need_bn_update = True self.fit = True - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - self._load(state_dict) - return super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - def load_state_dict( self, state_dict: Mapping, strict: bool = True, assign: bool = False ): - self._load(state_dict) + self._load_swag_stats(state_dict) return super().load_state_dict(state_dict, strict, assign) def compute_logdet(self, block=False): diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index 2dfd35c4..bc5a59cf 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,5 +1,5 @@ # ruff: noqa: F401 from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler -from .laplace import Laplace +from .laplace import LaplaceApprox from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 9a4940e4..e0b1203b 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -2,7 +2,7 @@ from typing import Literal from torch import Tensor, nn -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset if util.find_spec("laplace"): from laplace import Laplace @@ -10,7 +10,7 @@ laplace_installed = True -class Laplace(nn.Module): +class LaplaceApprox(nn.Module): def __init__( self, task: Literal["classification", "regression"], @@ -21,6 +21,7 @@ def __init__( link_approx: Literal[ "mc", "probit", "bridge", "bridge_norm" ] = "probit", + batch_size: int = 256, ) -> None: """Laplace approximation for uncertainty estimation. @@ -38,12 +39,14 @@ def __init__( link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional): how to approximate the classification link function for the `'glm'`. See the Laplace library for more details. Defaults to "probit". + batch_size (int, optional): batch size for the Laplace approximation. + Defaults to 256. Reference: Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021. """ super().__init__() - if not laplace_installed: + if not laplace_installed: # coverage: ignore raise ImportError( "The laplace-torch library is not installed. Please install it via `pip install laplace-torch`." ) @@ -53,6 +56,7 @@ def __init__( self.task = task self.weight_subset = weight_subset self.hessian_struct = hessian_struct + self.batch_size = batch_size if model is not None: self._setup_model(model) @@ -60,16 +64,17 @@ def __init__( def _setup_model(self, model) -> None: self.la = Laplace( model=model, - task=self.task, - weight_subset=self.weight_subset, - hessian_struct=self.hessian_struct, + likelihood=self.task, + subset_of_weights=self.weight_subset, + hessian_structure=self.hessian_struct, ) def set_model(self, model: nn.Module) -> None: self._setup_model(model) def fit(self, dataset: Dataset) -> None: - self.la.fit(dataset=dataset) + dl = DataLoader(dataset, batch_size=self.batch_size) + self.la.fit(train_loader=dl) def forward( self, From 501b5d45ffabf4c06f98dfe5318524a7fb5ea638 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 17:31:58 +0200 Subject: [PATCH 36/57] :heavy_check_mark: Fix tests --- tests/_dummies/baseline.py | 10 +++++++--- tests/routines/test_classification.py | 10 +++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index d763800d..c87c71d5 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -7,7 +7,7 @@ NormalInverseGammaLayer, NormalLayer, ) -from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.models import EMA, SWA, deep_ensembles from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.routines import ( @@ -30,7 +30,6 @@ def __new__( baseline_type: str = "single", optim_recipe=optim_cifar10_resnet18, with_feats: bool = True, - with_linear: bool = True, ood_criterion: str = "msp", eval_ood: bool = False, eval_grouping_loss: bool = False, @@ -44,13 +43,18 @@ def __new__( mixup_alpha: float = 0, cutmix_alpha: float = 0, no_mixup_params: bool = False, + ema: bool = False, + swa: bool = False, ) -> ClassificationRoutine: model = dummy_model( in_channels=in_channels, num_classes=num_classes, with_feats=with_feats, - with_linear=with_linear, ) + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=1) if not no_mixup_params: mixup_params = { "mixup_alpha": mixup_alpha, diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index ebd6a736..888087f4 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -31,6 +31,7 @@ def test_one_estimator_binary(self): loss=nn.BCEWithLogitsLoss(), baseline_type="single", ood_criterion="msp", + ema=True, ) trainer.fit(model, dm) @@ -53,6 +54,7 @@ def test_two_estimators_binary(self): loss=nn.BCEWithLogitsLoss(), baseline_type="single", ood_criterion="logit", + swa=True, ) trainer.fit(model, dm) @@ -384,13 +386,7 @@ def test_classification_failures(self): eval_grouping_loss=True, ) - model = dummy_model(1, 1, 0, with_feats=False, with_linear=True) - with pytest.raises(ValueError): - ClassificationRoutine( - num_classes=10, model=model, loss=None, eval_grouping_loss=True - ) - - model = dummy_model(1, 1, 0, with_feats=True, with_linear=False) + model = dummy_model(1, 1, 0, with_feats=False) with pytest.raises(ValueError): ClassificationRoutine( num_classes=10, model=model, loss=None, eval_grouping_loss=True From d135612f4ce959fdd5faba0dc796ec35cc154952 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 17:37:19 +0200 Subject: [PATCH 37/57] :white_check_mark: Improve SWAG tests --- tests/models/wrappers/test_swa.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py index 75fcdafe..4d3171ef 100644 --- a/tests/models/wrappers/test_swa.py +++ b/tests/models/wrappers/test_swa.py @@ -51,10 +51,15 @@ def test_training(self): swag(torch.randn(1, 1)) swag.update_model(0) swag.update_bn(dl, "cpu") + swag(torch.randn(1, 1)) swag.update_model(1) swag.update_bn(dl, "cpu") + swag.update_model(2) + swag.update_bn(dl, "cpu") + swag(torch.randn(1, 1)) + swag.eval() swag(torch.randn(1, 1)) From f1b65467d79ae9edf54b1479c9a7cb054b5195b8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 17:38:56 +0200 Subject: [PATCH 38/57] :white_check_mark: Improve Stochastic tests --- tests/models/wrappers/test_stochastic_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/wrappers/test_stochastic_model.py b/tests/models/wrappers/test_stochastic_model.py index c7709650..80204cd6 100644 --- a/tests/models/wrappers/test_stochastic_model.py +++ b/tests/models/wrappers/test_stochastic_model.py @@ -1,3 +1,4 @@ +import torch from torch import nn from torch_uncertainty.layers import BayesConv2d, BayesLinear @@ -46,9 +47,11 @@ class TestStochasticModel: def test_main(self): model = StochasticModel(DummyModelLinear(), 2) model.freeze() + model(torch.randn(1, 1)) assert model.model.layer.frozen model.unfreeze() assert not model.model.layer.frozen + model(torch.randn(1, 1)) model = StochasticModel(DummyModelConv(), 2) model.freeze() From edbb88e55ae3e786e7f04a353499997674191618 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 17:43:05 +0200 Subject: [PATCH 39/57] :shirt: Minor changes --- .../metrics/classification/fpr95.py | 42 +++---------------- torch_uncertainty/routines/segmentation.py | 2 +- 2 files changed, 6 insertions(+), 38 deletions(-) diff --git a/torch_uncertainty/metrics/classification/fpr95.py b/torch_uncertainty/metrics/classification/fpr95.py index 87a1b93a..f7e4a660 100644 --- a/torch_uncertainty/metrics/classification/fpr95.py +++ b/torch_uncertainty/metrics/classification/fpr95.py @@ -1,43 +1,11 @@ import numpy as np import torch -from numpy.typing import ArrayLike from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -def stable_cumsum(arr: ArrayLike, rtol: float = 1e-05, atol: float = 1e-08): - """Uses high precision for cumsum and checks that the final value matches - the sum. - - Args: - arr (ArrayLike): The array to be cumulatively summed as flat. - rtol (float, optional): Relative tolerance, see ``np.allclose``. - Defaults to 1e-05. - atol (float, optional): Absolute tolerance, see ``np.allclose``. - Defaults to 1e-08. - - Returns: - ArrayLike: The cumulatively summed array. - - Reference: - From https://github.com/hendrycks/anomaly-seg. - - TODO: Check if necessary. - """ - out = np.cumsum(arr, dtype=np.float64) - expected = np.sum(arr, dtype=np.float64) - if not np.allclose( - out[-1], expected, rtol=rtol, atol=atol - ): # coverage: ignore - raise RuntimeError( - "cumsum was found to be unstable: " - "its last element does not correspond to sum" - ) - return out - - class FPRx(Metric): is_differentiable: bool = False higher_is_better: bool = False @@ -53,6 +21,9 @@ def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: recall_level (float): The recall level at which to compute the FPR. pos_label (int): The positive label. kwargs: Additional arguments to pass to the metric class. + + Reference: + Inpired by https://github.com/hendrycks/anomaly-seg. """ super().__init__(**kwargs) @@ -82,13 +53,10 @@ def update(self, conf: Tensor, target: Tensor) -> None: self.targets.append(target) def compute(self) -> Tensor: - r"""Compute the actual False Positive Rate at x% Recall. + """Compute the actual False Positive Rate at x% Recall. Returns: Tensor: The value of the FPRx. - - Reference: - Inpired by https://github.com/hendrycks/anomaly-seg. """ conf = dim_zero_cat(self.conf).cpu().numpy() targets = dim_zero_cat(self.targets).cpu().numpy() @@ -120,7 +88,7 @@ def compute(self) -> Tensor: threshold_idxs = np.r_[distinct_value_indices, labels.shape[0] - 1] # accumulate the true positives with decreasing threshold - tps = stable_cumsum(labels)[threshold_idxs] + tps = np.cumsum(labels)[threshold_idxs] fps = 1 + threshold_idxs - tps # add one because of zero-based indexing thresholds = examples[threshold_idxs] diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 4d1b2fbc..3514f612 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -30,7 +30,7 @@ def __init__( log_plots: bool = False, num_calibration_bins: int = 15, ) -> None: - """Routine for training & testing on segmentation tasks. + r"""Routine for training & testing on segmentation tasks. Args: model (torch.nn.Module): Model to train. From fdbaf76b4832ce33cce02b50005d194aea0f626d Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Jun 2024 17:54:27 +0200 Subject: [PATCH 40/57] :white_check_mark: Finetune tests --- tests/layers/test_distributions.py | 6 ++++++ tests/models/test_lenet.py | 11 ++++++++++- tests/models/wrappers/test_stochastic_model.py | 1 + tests/post_processing/test_laplace.py | 12 +++++++++++- tests/post_processing/test_mc_batch_norm.py | 9 +++++++++ tests/test_optim_recipes.py | 14 +++++++++++--- 6 files changed, 48 insertions(+), 5 deletions(-) diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py index b1fbf4bd..85a67c6f 100644 --- a/tests/layers/test_distributions.py +++ b/tests/layers/test_distributions.py @@ -1,12 +1,18 @@ import pytest from torch_uncertainty.layers.distributions import ( + AbstractDist, LaplaceLayer, NormalLayer, ) class TestDistributions: + def test(self): + AbstractDist.__abstractmethods__ = set() + dist = AbstractDist(dim=1) + dist.forward(None) + def test_errors(self): with pytest.raises(ValueError): NormalLayer(-1, 1) diff --git a/tests/models/test_lenet.py b/tests/models/test_lenet.py index a0c6446a..8519ffdf 100644 --- a/tests/models/test_lenet.py +++ b/tests/models/test_lenet.py @@ -20,7 +20,16 @@ def test_main(self): packed_lenet(1, 1) bayesian_lenet(1, 1) - bayesian_lenet(1, 1, 1, 1, 1, 0, 1) + bayesian_lenet( + in_channels=1, + num_classes=1, + num_samples=1, + prior_sigma_1=1, + prior_sigma_2=1, + prior_pi=0, + mu_init=1, + sigma_init=1, + ) def test_errors(self): with pytest.raises(ValueError): diff --git a/tests/models/wrappers/test_stochastic_model.py b/tests/models/wrappers/test_stochastic_model.py index 80204cd6..76b60332 100644 --- a/tests/models/wrappers/test_stochastic_model.py +++ b/tests/models/wrappers/test_stochastic_model.py @@ -51,6 +51,7 @@ def test_main(self): assert model.model.layer.frozen model.unfreeze() assert not model.model.layer.frozen + model.eval() model(torch.randn(1, 1)) model = StochasticModel(DummyModelConv(), 2) diff --git a/tests/post_processing/test_laplace.py b/tests/post_processing/test_laplace.py index c9aa8e72..6b798d6b 100644 --- a/tests/post_processing/test_laplace.py +++ b/tests/post_processing/test_laplace.py @@ -3,7 +3,17 @@ from torch.utils.data import TensorDataset from tests._dummies.model import dummy_model -from torch_uncertainty.post_processing import LaplaceApprox +from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing + + +class TestPostProcessing: + """Testing the PostProcessing class.""" + + def test_errors(self): + PostProcessing.__abstractmethods__ = set() + pp = PostProcessing(nn.Identity()) + pp.fit(None) + pp.forward(None) class TestLaplace: diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py index fa72caee..0d1acdca 100644 --- a/tests/post_processing/test_mc_batch_norm.py +++ b/tests/post_processing/test_mc_batch_norm.py @@ -42,11 +42,20 @@ def test_main(self): stoch_model.eval() stoch_model(torch.randn(1, 1, 20, 20)) + stoch_model = MCBatchNorm( + num_estimators=2, convert=False, mc_batch_size=1 + ) + stoch_model.set_model(mc_model) + def test_errors(self): """Test errors.""" model = nn.Identity() with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=0, convert=True) + with pytest.raises( + ValueError, match="mc_batch_size must be a positive integer" + ): + MCBatchNorm(model, num_estimators=1, convert=True, mc_batch_size=-1) with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=1, convert=False) with pytest.raises(ValueError): diff --git a/tests/test_optim_recipes.py b/tests/test_optim_recipes.py index b71ac43f..a3e49530 100644 --- a/tests/test_optim_recipes.py +++ b/tests/test_optim_recipes.py @@ -2,9 +2,17 @@ import pytest import torch -from torch_uncertainty.optim_recipes import ( - get_procedure, -) +from torch_uncertainty.optim_recipes import FullSWALR, get_procedure + + +class TestFullSWALR: + def test_full_swa_lr(self): + FullSWALR( + torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), lr=1e-3), + swa_lr=1, + milestone=12, + anneal_epochs=5, + ) class TestOptProcedures: From a601ff9a26a50130e275e26eb4c2568476b6a020 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 10:00:43 +0200 Subject: [PATCH 41/57] :ok_hand: Take review comments into account --- auto_tutorials_source/tutorial_bayesian.py | 5 ++-- docs/source/references.rst | 32 ++++++++++++++-------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 7ddd5c37..68628d2a 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -128,7 +128,7 @@ def optim_lenet(model: nn.Module): # # Now that the model is trained, let's test it on MNIST. # Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble -# and to the batch. As for TorchUncertainty 2.0, the ensemble dimension is merged with the batch dimension +# and to the batch. As for TorchUncertainty 0.2.0, the ensemble dimension is merged with the batch dimension # in this order (num_estimator x batch, classes). import matplotlib.pyplot as plt import numpy as np @@ -165,7 +165,8 @@ def imshow(img): print("Predicted digits: ", " ".join(f"{predicted[j]}" for j in range(4))) print("Std. dev. of the scores over the posterior samples", " ".join(f"{var_probs[j][predicted[j]]:.3}" for j in range(4))) # %% -# The scores should be quite certain. +# Here, we show the variance of the top prediction. This is a non-standard but intuitive way to show the diversity of the predictions +# of the ensemble. Ideally, the variance should be high when the average top prediction is incorrect. # # References # ---------- diff --git a/docs/source/references.rst b/docs/source/references.rst index 66cfdcd6..3f9bd91f 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -15,7 +15,7 @@ For Deep Evidential Classification, consider citing: **Evidential Deep Learning to Quantify Classification Uncertainty** -* Authors: *Murat Sensoy, Lance Kaplan, Melih Kandemir* +* Authors: *Murat Sensoy, Lance Kaplan, and Melih Kandemir* * Paper: `NeurIPS 2018 `__. @@ -26,7 +26,7 @@ For Beta NLL in Deep Regression, consider citing: **On the Pitfalls of Heteroscedastic Uncertainty Estimation with Probabilistic Neural Networks** -* Authors: *Maximilian Seitzer, Arash Tavakoli, Dimitrije Antic, Georg Martius* +* Authors: *Maximilian Seitzer, Arash Tavakoli, Dimitrije Antic, and Georg Martius* * Paper: `ICLR 2022 `__. @@ -37,7 +37,7 @@ For Deep Evidential Regression, consider citing: **Deep Evidential Regression** -* Authors: *Alexander Amini, Wilko Schwarting, Ava Soleimany, Daniela Rus* +* Authors: *Alexander Amini, Wilko Schwarting, Ava Soleimany, and Daniela Rus* * Paper: `NeurIPS 2020 `__. @@ -80,7 +80,7 @@ For Stochastic Weight Averaging, consider citing: **Averaging Weights Leads to Wider Optima and Better Generalization** -* Authors: *Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson* +* Authors: *Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson* * Paper: `UAI 2018 `__. Stochastic Weight Averaging Gaussian @@ -101,9 +101,19 @@ For CheckpointEnsemble, consider citing: **Checkpoint Ensembles: Ensemble Methods from a Single Training Process** -* Authors: *Hugh Chen, Scott Lundberg, Su-In Lee* +* Authors: *Hugh Chen, Scott Lundberg, and Su-In Lee* * Paper: `ArXiv `__. +SnapshotEnsemble +^^^^^^^^^^^^^^^^ + +For SnapshotEnsemble, consider citing: + +**Snapshot Ensembles: Train 1, get M for free** + +* Authors: *Gao Huang, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E. Hopcroft, and Kilian Q. Weinberger +* Paper: `ICLR 2017 `__. + BatchEnsemble ^^^^^^^^^^^^^ @@ -177,7 +187,7 @@ For RegMixup, consider citing: **RegMixup: Mixup as a Regularizer Can Surprisingly Improve Accuracy and Out Distribution Robustness** -* Authors: *Francesco Pinto, Harry Yang, Ser-Nam Lim, Philip H.S. Torr, Puneet K. Dokania* +* Authors: *Francesco Pinto, Harry Yang, Ser-Nam Lim, Philip H.S. Torr, and Puneet K. Dokania* * Paper: `NeurIPS 2022 `__. MixupIO @@ -230,7 +240,7 @@ For Laplace Approximation, consider citing: **Laplace Redux - Effortless Bayesian Deep Learning** -* Authors: *Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, Philipp Hennig* +* Authors: *Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig* * Paper: `NeurIPS 2021 `__. Metrics @@ -255,7 +265,7 @@ For the adaptive calibration error, consider citing: **Measuring Calibration in Deep Learning** -* Authors: Jeremy Nixon, Mike Dusenberry, Ghassen Jerfel, Timothy Nguyen, Jeremiah Liu, Linchuan Zhang, Dustin Tran +* Authors: *Jeremy Nixon, Mike Dusenberry, Ghassen Jerfel, Timothy Nguyen, Jeremiah Liu, Linchuan Zhang, and Dustin Tran* * Paper: `CVPRW 2019 `__. Area Under the Risk-Coverage curve @@ -265,7 +275,7 @@ For the area under the risk-coverage curve, consider citing: **Selective classification for deep neural networks** -* Authors: Yonatan Geifman, Ran El-Yaniv +* Authors: *Yonatan Geifman and Ran El-Yaniv* * Paper: `NeurIPS 2017 `__. Grouping Loss @@ -335,7 +345,7 @@ CIFAR-10 N / CIFAR-100 N **Learning with Noisy Labels Revisited: A Study Using Real-World Human Annotations** -* Authors: *Jiaheng Wei, Zhaowei Zhu, Hao Cheng, Tongliang Liu, Gang Niu, Yang Liu* +* Authors: *Jiaheng Wei, Zhaowei Zhu, Hao Cheng, Tongliang Liu, Gang Niu, and Yang Liu* * Paper: `ICLR 2022 `__. SVHN @@ -400,7 +410,7 @@ MUAD **MUAD: Multiple Uncertainties for Autonomous Driving Dataset** -* Authors: Gianni Franchi, Xuanlong Yu, Andrei Bursuc, et al.* +* Authors: *Gianni Franchi, Xuanlong Yu, Andrei Bursuc, et al.* * Paper: `BMVC 2022 __` Architectures From 3ea90e911a601cfaf092a2bc05b0cde2e3ad188d Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 10:45:22 +0200 Subject: [PATCH 42/57] :books: Improve documentation & tutorials --- ...{tutorial_corruptions.py => tutorial_corruption.py} | 4 ++-- auto_tutorials_source/tutorial_der_cubic.py | 6 ++---- auto_tutorials_source/tutorial_mc_batch_norm.py | 4 ++-- auto_tutorials_source/tutorial_mc_dropout.py | 10 ++++------ auto_tutorials_source/tutorial_pe_cifar10.py | 5 +++-- auto_tutorials_source/tutorial_scaler.py | 4 ++-- docs/source/api.rst | 2 +- docs/source/references.rst | 2 +- torch_uncertainty/routines/segmentation.py | 2 +- 9 files changed, 18 insertions(+), 21 deletions(-) rename auto_tutorials_source/{tutorial_corruptions.py => tutorial_corruption.py} (95%) diff --git a/auto_tutorials_source/tutorial_corruptions.py b/auto_tutorials_source/tutorial_corruption.py similarity index 95% rename from auto_tutorials_source/tutorial_corruptions.py rename to auto_tutorials_source/tutorial_corruption.py index d20e4f19..9e4e7a10 100644 --- a/auto_tutorials_source/tutorial_corruptions.py +++ b/auto_tutorials_source/tutorial_corruption.py @@ -1,6 +1,6 @@ """ -Image Corruptions -================= +Corrupting Images with TorchUncertainty to Benchmark Robustness +=============================================================== This tutorial shows the impact of the different corruptions available in the TorchUncertainty library. These corruptions were first proposed in the paper diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index b77a0a4d..62afd3ef 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -55,9 +55,7 @@ def optim_regression( lr=learning_rate, weight_decay=0, ) - return { - "optimizer": optimizer, - } + return optimizer # %% @@ -69,7 +67,7 @@ def optim_regression( # Please note that this MLP finishes with a NormalInverseGammaLayer that interpret the outputs of the model # as the parameters of a Normal Inverse Gamma distribution. -trainer = Trainer(accelerator="cpu", max_epochs=50)#, enable_progress_bar=False) +trainer = Trainer(accelerator="cpu", max_epochs=50) #, enable_progress_bar=False) # dataset train_ds = Cubic(num_samples=1000) diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 3827472d..140559fc 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -76,7 +76,7 @@ # `trainer.test`. trainer.fit(model=routine, datamodule=datamodule) -trainer.test(model=routine, datamodule=datamodule); +perf = trainer.test(model=routine, datamodule=datamodule) # %% # 5. Wrapping the Model in a MCBatchNorm @@ -93,7 +93,7 @@ routine.model, num_estimators=8, convert=True, mc_batch_size=16 ) routine.model.fit(datamodule.train) -routine.eval(); +routine = routine.eval() # To avoid prints # %% # 6. Testing the Model diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index cd410526..2d5bf925 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -22,8 +22,8 @@ - the Trainer from Lightning - the datamodule handling dataloaders: MNISTDataModule from torch_uncertainty.datamodules - the model: LeNet, which lies in torch_uncertainty.models -- the MC Dropout wrapper: mc_dropout, which lies in torch_uncertainty.models -- the classification training routine in the torch_uncertainty.routines +- the MC Dropout wrapper: mc_dropout, from torch_uncertainty.models.wrappers +- the classification training & evaluation routine in the torch_uncertainty.routines - an optimization recipe in the torch_uncertainty.optim_recipes module. We also need import the neural network utils within `torch.nn`. @@ -76,16 +76,14 @@ # This is a classification problem, and we use CrossEntropyLoss as the likelihood. # We define the training routine using the classification training routine from # torch_uncertainty.routines.classification. We provide the number of classes -# and channels, the optimizer wrapper, the dropout rate, and the number of -# forward passes to perform through the network, as well as all the default -# arguments. +# and channels, the optimizer wrapper, and the dropout rate. routine = ClassificationRoutine( num_classes=datamodule.num_classes, model=mc_model, loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18(mc_model), - num_estimators=16, + is_ensemble=True, ) # %% diff --git a/auto_tutorials_source/tutorial_pe_cifar10.py b/auto_tutorials_source/tutorial_pe_cifar10.py index 57b6b51f..d3a233cd 100644 --- a/auto_tutorials_source/tutorial_pe_cifar10.py +++ b/auto_tutorials_source/tutorial_pe_cifar10.py @@ -45,6 +45,7 @@ import torch import torchvision import torchvision.transforms as transforms +from torch.utils.data import DataLoader torch.set_num_threads(1) @@ -69,14 +70,14 @@ trainset = torchvision.datasets.CIFAR10( root="./data", train=True, download=True, transform=transform ) -trainloader = torch.utils.data.DataLoader( +trainloader = DataLoader( trainset, batch_size=batch_size, shuffle=True, num_workers=2 ) testset = torchvision.datasets.CIFAR10( root="./data", train=False, download=True, transform=transform ) -testloader = torch.utils.data.DataLoader( +testloader = DataLoader( testset, batch_size=batch_size, shuffle=False, num_workers=2 ) diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index 75f1953c..ceaaa036 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -25,7 +25,7 @@ If you use the classification routine, the plots will be automatically available in the tensorboard logs if you use the `log_plots` flag. """ - +# %% from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.metrics import CalibrationError from torch_uncertainty.models.resnet import resnet @@ -114,7 +114,7 @@ # Fit the scaler on the calibration dataset scaled_model = TemperatureScaler(model=model) -scaled_model = scaled_model.fit(calibration_set=cal_dataset) +scaled_model.fit(calibration_set=cal_dataset) # %% # 6. Iterating Again to Compute the Improved ECE diff --git a/docs/source/api.rst b/docs/source/api.rst index de590aa1..ba4e3ef5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -243,7 +243,7 @@ Post-Processing Methods .. autosummary:: :toctree: generated/ :nosignatures: - :template: class_inherited.rst + :template: class.rst MCBatchNorm LaplaceApprox diff --git a/docs/source/references.rst b/docs/source/references.rst index 3f9bd91f..89829c16 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -111,7 +111,7 @@ For SnapshotEnsemble, consider citing: **Snapshot Ensembles: Train 1, get M for free** -* Authors: *Gao Huang, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E. Hopcroft, and Kilian Q. Weinberger +* Authors: *Gao Huang, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E. Hopcroft, and Kilian Q. Weinberger* * Paper: `ICLR 2017 `__. BatchEnsemble diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 3514f612..b372a04d 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -45,7 +45,7 @@ def __init__( metric_subsampling_rate (float, optional): The rate of subsampling for the memory consuming metrics. Defaults to ``1e-2``. log_plots (bool, optional): Indicates whether to log plots from - metrics. Defaults to ``False` + metrics. Defaults to ``False``. num_calibration_bins (int, optional): Number of bins to compute calibration metrics. Defaults to ``15``. From 84fb04fb315a5b18e71cb952aac921b8396cdbe2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 10:47:22 +0200 Subject: [PATCH 43/57] :book: Add a tutorial on Packed-Ensembles --- .../tutorial_from_de_to_pe.py | 417 ++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 auto_tutorials_source/tutorial_from_de_to_pe.py diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py new file mode 100644 index 00000000..75f39f22 --- /dev/null +++ b/auto_tutorials_source/tutorial_from_de_to_pe.py @@ -0,0 +1,417 @@ +"""Improved Ensemble parameter-efficiency with Packed-Ensembles +============================================================ + +*This tutorial is adapted from a notebook part of a lecture given at the [Helmholtz AI Conference](https://haicon24.de/) by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.* + +In this notebook will work on the MNIST dataset that was introduced by Corinna Cortes, Christopher J.C. Burges, and later modified by Yann LeCun in the foundational paper: + +- [Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE.](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) + +The MNIST dataset consists of 70 000 images of handwritten digits from 0 to 9. The images are grayscale and 28x28-pixel sized. The task is to classify the images into their respective digits. The dataset can be automatically downloaded using the `torchvision` library. + +In this notebook, we will train a model and an ensemble on this task and evaluate their performance. The performance will consist in the following metrics: +- Accuracy: the proportion of correctly classified images +- Brier score: a measure of the quality of the predicted probabilities +- Calibration error: a measure of the calibration of the predicted probabilities + +Throughout this notebook, we abstract the training and evaluation process using [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) +and [TorchUncertainty](https://torch-uncertainty.github.io/). + +Similarly to keras for tensorflow, PyTorch Lightning is a high-level interface for PyTorch that simplifies the training and evaluation process using a Trainer. +TorchUncertainty is partly built on top of PyTorch Lightning and provides tools to train and evaluate models with uncertainty quantification. + +TorchUncertainty includes datamodules that handle the data loading and preprocessing. We don't use them here for tutorial purposes. +""" +# %% +# 1. Download, instantiate and visualize the datasets +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The dataset is automatically downloaded using torchvision. We will then visualize a few images to get a sense of the data. + +# Create the transforms for the images +import torch +import torchvision.transforms as T + +# We set the number of epochs to some low value for the sake of time +max_epochs = 2 + +train_transform = T.Compose( + [ + T.ToTensor(), + # We perform random cropping as data augmentation + T.RandomCrop(28, padding=4), + # As for the MNIST1d dataset, we normalize the data + T.Normalize((0.1307,), (0.3081,)), + ] +) +test_transform = T.Compose( + [ + T.Grayscale(num_output_channels=1), + T.ToTensor(), + T.CenterCrop(28), + T.Normalize((0.1307,), (0.3081,)), + ] +) + +# Download and instantiate the dataset +from torch.utils.data import Subset +from torchvision.datasets import MNIST, FashionMNIST + +train_data = MNIST( + root="./data/", download=True, train=True, transform=train_transform +) +test_data = MNIST(root="./data/", train=False, transform=test_transform) +# We only take the first 10k images to have the same number of samples as the test set using torch Subsets +ood_data = Subset( + FashionMNIST(root="./data/", download=True, transform=test_transform), + indices=range(10000), +) + +# Create the corresponding dataloaders +from torch.utils.data import DataLoader + +train_dl = DataLoader(train_data, batch_size=32, shuffle=True) +test_dl = DataLoader(test_data, batch_size=32, shuffle=False) +ood_dl = DataLoader(ood_data, batch_size=32, shuffle=False) + +# %% +# You could replace all this cell by simply loading the MNIST datamodule from TorchUncertainty. +# Now, let's visualize a few images from the dataset. For this task, we use the viz_data dataset that applies no transformation to the images. + +# Datasets without transformation to visualize the unchanged data +viz_data = MNIST(root="./data/", train=False) +ood_viz_data = FashionMNIST(root="./data/", download=True) + +print("In distribution data:") +viz_data[0][0] +# %% +print("Out of distribution data:") +ood_viz_data[0][0] + +# %% +# 2. Create & train the model +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We will create a simple convolutional neural network (CNN): the LeNet model (also introduced by LeCun). +import torch.nn as nn +import torch.nn.functional as F + + +class LeNet(nn.Module): + def __init__( + self, + in_channels: int, + num_classes: int, + ) -> None: + super().__init__() + self.conv1 = nn.Conv2d(in_channels, 6, (5, 5)) + self.conv2 = nn.Conv2d(6, 16, (5, 5)) + self.pooling = nn.AdaptiveAvgPool2d((4, 4)) + self.fc1 = nn.Linear(256, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = torch.flatten(out, 1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + return self.fc3(out) # No softmax in the model! + + +# Instantiate the model, the images are in grayscale so the number of channels is 1 +model = LeNet(in_channels=1, num_classes=10) + +# %% +# We now need to define the optimization recipe: +# - the optimizer, here the standard stochastic gradient descent (SGD) with a learning rate of 0.05 +# - the scheduler, here cosine annealing. + + +def optim_recipe(model, lr_mult: float = 1.0): + optimizer = torch.optim.SGD(model.parameters(), lr=0.05 * lr_mult) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) + return {"optimizer": optimizer, "scheduler": scheduler} + + +# %% +# To train the model, we use [TorchUncertainty](https://torch-uncertainty.github.io/), a library that we have developed to ease +# the training and evaluation of models with uncertainty. You can have a look at the +# [documentation](https://torch-uncertainty.github.io/) and the [code](https://github.com/ENSTA-U2IS-AI/torch-uncertainty). +# +# **Note:** To train supervised classification models we most often use the cross-entropy loss. +# With weight-decay, this minimizing this loss amounts to finding a Maximum a posteriori (MAP) estimate of the model parameters. +# This means that the model is trained to predict the most likely class for each input. + + +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TUTrainer + +# Create the trainer that will handle the training +trainer = TUTrainer(accelerator="cpu", max_epochs=max_epochs) + +# The routine is a wrapper of the model that contains the training logic with the metrics, etc +routine = ClassificationRoutine( + num_classes=10, + model=model, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_recipe(model), + eval_ood=True, +) + +# In practice, avoid performing the validation on the test set (if you do model selection) +trainer.fit(routine, train_dataloaders=train_dl, val_dataloaders=test_dl) + +# %% +# Evaluate the trained model on the test set - pay attention to the cls/Acc metric +perf = trainer.test(routine, dataloaders=[test_dl, ood_dl]) + +# %% +# This table provides a lot of information: +# +# **OOD Detection: Binary Classification MNIST vs. FashionMNIST** +# - AUPR/AUROC/FPR95: Measures the quality of the OOD detection. The higher the better for AUPR and AUROC, the lower the better for FPR95. +# +# **Calibration: Reliability of the Predictions** +# - ECE: Expected Calibration Error. The lower the better. +# - aECE: Adaptive Expected Calibration Error. The lower the better. (~More precise version of the ECE) +# +# **Classification Performance** +# - Accuracy: The ratio of correctly classified images. The higher the better. +# - Brier: The quality of the predicted probabilities (Mean Squared Error of the predictions vs. ground-truth). The lower the better. +# - Negative Log-Likelihood: The value of the loss on the test set. The lower the better. +# +# **Selective Classification & Grouping Loss** +# - We talk about these points later in the "To go further" section. +# +# 3. Training an ensemble of models with TorchUncertainty +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# You have two options here, you can either train the ensemble directly if you have enough memory, +# otherwise, you can train independent models and do the ensembling during the evaluation (sometimes called inference). +# +# In this case, we will do it sequentially. In this tutorial, you have the choice between training multiple models, +# which will take time if you have no GPU, or downloading the pre-trained models that we have prepared for you. +# +# Training the ensemble +# +# To train the ensemble, you will have to use the "deep_ensembles" function from TorchUncertainty, which will +# replicate and change the initialization of your networks to ensure diversity. + +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.transforms import RepeatTarget + +# Create the ensemble model +ensemble = deep_ensembles( + LeNet(in_channels=1, num_classes=10), + num_estimators=2, + task="classification", + reset_model_parameters=True, +) + +trainer = TUTrainer(accelerator="cpu", max_epochs=1) +ens_routine = ClassificationRoutine( + is_ensemble=True, + num_classes=10, + model=ensemble, + loss=nn.CrossEntropyLoss(), # The loss for the training + format_batch_fn=RepeatTarget( + 2 + ), # How to handle the targets when comparing the predictions + optim_recipe=optim_recipe( + ensemble, 2.0 + ), # The optimization scheme with the optimizer and the scheduler as a dictionnary + eval_ood=True, # We want to evaluate the OOD-related metrics +) +trainer.fit(ens_routine, train_dataloaders=train_dl, val_dataloaders=test_dl) +ens_perf = trainer.test(ens_routine, dataloaders=[test_dl, ood_dl]) + +# %% +# The results are not comparable since we only trained the ensemble for one epoch to reduce GitHub's cpu usage. +# Feel free to run the notebook on your machine for a longer duration. +# +# We need to multiply the learning rate by 2 to account for the fact that we have 4 models +# in the ensemble and that we average the loss over all the predictions. +# +# #### Downloading the pre-trained models +# +# We have put the pre-trained models on Hugging Face that you can download with the utility function +# "hf_hub_download" imported just below. These models are trained for 75 epochs and are therefore not +# comparable to the all the other models trained in this notebook. The pretrained models can be seen +# [here](https://huggingface.co/ENSTA-U2IS/tutorial-models) and TorchUncertainty's are [here](https://huggingface.co/torch-uncertainty). + +from torch_uncertainty.utils.hub import hf_hub_download + +all_models = [] +for i in range(8): + hf_hub_download( + repo_id="ENSTA-U2IS/tutorial-models", + filename=f"version_{i}.ckpt", + local_dir="./models/", + ) + model = LeNet(in_channels=1, num_classes=10) + state_dict = torch.load(f"./models/version_{i}.ckpt", map_location="cpu")[ + "state_dict" + ] + state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} + model.load_state_dict(state_dict) + all_models.append(model) + +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.transforms import RepeatTarget + +ensemble = deep_ensembles( + all_models, + num_estimators=None, + task="classification", + reset_model_parameters=True, +) + +ens_routine = ClassificationRoutine( + is_ensemble=True, + num_classes=10, + model=ensemble, + loss=nn.CrossEntropyLoss(), # The loss for the training + format_batch_fn=RepeatTarget( + 8 + ), # How to handle the targets when comparing the predictions + optim_recipe=None, # No optim recipe as the model is already trained + eval_ood=True, # We want to evaluate the OOD-related metrics +) + +trainer = TUTrainer(accelerator="cpu", max_epochs=max_epochs) + +ens_perf = trainer.test(ens_routine, dataloaders=[test_dl, ood_dl]) + +# %% +# 4. From Deep Ensembles to Packed-Ensembles +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In the paper [Packed-Ensembles for Efficient Uncertainty Quantification](https://arxiv.org/abs/2210.09184) +# published at the International Conference on Learning Representations (ICLR) last year, we introduced a +# modification of Deep Ensembles to make it more computationally-efficient. The idea is to pack the ensemble +# members into a single model, which allows us to train the ensemble in a single forward pass. +# This modification is particularly useful when the ensemble size is large, as it is often the case in practice. +# +# We will need to update the model and replace the layers with their Packed equivalents. You can find the +# documentation of the Packed-Linear layer [here](https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html), +# and the Packed-Conv2D, [here](https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html). + +import torch +import torch.nn as nn +from einops import rearrange + +from torch_uncertainty.layers import PackedConv2d, PackedLinear + + +class PackedLeNet(nn.Module): + def __init__( + self, + in_channels: int, + num_classes: int, + alpha: int, + num_estimators: int, + ) -> None: + super().__init__() + self.num_estimators = num_estimators + self.conv1 = PackedConv2d( + in_channels, + 6, + (5, 5), + alpha=alpha, + num_estimators=num_estimators, + first=True, + ) + self.conv2 = PackedConv2d( + 6, + 16, + (5, 5), + alpha=alpha, + num_estimators=num_estimators, + ) + self.pooling = nn.AdaptiveAvgPool2d((4, 4)) + self.fc1 = PackedLinear( + 256, 120, alpha=alpha, num_estimators=num_estimators + ) + self.fc2 = PackedLinear( + 120, 84, alpha=alpha, num_estimators=num_estimators + ) + self.fc3 = PackedLinear( + 84, + num_classes, + alpha=alpha, + num_estimators=num_estimators, + last=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = rearrange( + out, "e (m c) h w -> (m e) c h w", m=self.num_estimators + ) + out = torch.flatten(out, 1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + return self.fc3(out) # Again, no softmax in the model + + +# Instantiate the model, the images are in grayscale so the number of channels is 1 +packed_model = PackedLeNet( + in_channels=1, num_classes=10, alpha=2, num_estimators=4 +) + +# Create the trainer that will handle the training +trainer = TUTrainer(accelerator="cpu", max_epochs=max_epochs) + +# The routine is a wrapper of the model that contains the training logic with the metrics, etc +packed_routine = ClassificationRoutine( + is_ensemble=True, + num_classes=10, + model=packed_model, + loss=nn.CrossEntropyLoss(), + format_batch_fn=RepeatTarget(4), + optim_recipe=optim_recipe(packed_model, 4.0), + eval_ood=True, +) + +# In practice, avoid performing the validation on the test set +trainer.fit(packed_routine, train_dataloaders=train_dl, val_dataloaders=test_dl) + +packed_perf = trainer.test(packed_routine, dataloaders=[test_dl, ood_dl]) + +# %% +# The training time should be approximately similar to the one of the single model that you trained before. However, please note that we are working with very small models, hence completely underusing the your GPU. As such, the training time is not representative of what you would observe with larger models. +# +# You can read more on Packed-Ensembles in the [paper](https://arxiv.org/abs/2210.09184) or the [Medium](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873) post. +# +# To Go Further & More Concepts of Uncertainty in ML +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# **Question 1:** Have a look at the models in the "lightning_logs". If you are on your own machine, try to visualize the learning curves with `tensorboard --logdir lightning_logs`. +# +# **Question 2:** Add a cell below and try to find the errors made by packed-ensembles on the test set. Visualize the errors and their labels and look at the predictions of the different sub-models. Are they similar? Can you think of uncertainty scores that could help you identify these errors? +# +# Selective Classification +# ^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Selective classification or "prediction with rejection" is a paradigm in uncertainty-aware machine learning where the model can decide not to make a prediction if the confidence score given by the model is below some pre-computed threshold. This can be useful in real-world applications where the cost of making a wrong prediction is high. +# +# In constrast to calibration, the values of the confidence scores are not important, only the order of the scores. *Ideally, the best model will order all the correct predictions first, and all the incorrect predictions last.* In this case, there will be a threshold so that all the predictions above the threshold are correct, and all the predictions below the threshold are incorrect. +# +# In TorchUncertainty, we look at 3 different metrics for selective classification: +# - **AURC**: The area under the Risk (% of errors) vs. Coverage (% of classified samples) curve. This curve expresses how the risk of the model evolves as we increase the coverage (the proportion of predictions that are above the selection threshold). This metric will be minimized by a model able to perfectly separate the correct and incorrect predictions. +# +# The following metrics are computed at a fixed risk and coverage level and that have practical interests. The idea of these metrics is that you can set the selection threshold to achieve a certain level of risk and coverage, as required by the technical constraints of your application: +# - **Coverage at 5% Risk**: The proportion of predictions that are above the selection threshold when it is set for the risk to egal 5%. Set the risk threshold to your application constraints. The higher the better. +# - **Risk at 80% Coverage**: The proportion of errors when the coverage is set to 80%. Set the coverage threshold to your application constraints. The lower the better. +# +# Grouping Loss +# ^^^^^^^^^^^^^ +# +# The grouping loss is a measure of uncertainty orthogonal to calibration. Have a look at [this paper](https://arxiv.org/abs/2210.16315) to learn about it. Check out the small library [GLest](https://github.com/aperezlebel/glest) to learn more about it. TorchUncertainty includes a wrapper of the library to compute the grouping loss with eval_grouping_loss parameter. From f6fb41c541cd1234aed2425e6e8a9831cd401e85 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 11:04:22 +0200 Subject: [PATCH 44/57] :white_check_mark: Improve tests --- ..._cifar10_datamodule.py => test_cifar10.py} | 0 ...ifar100_datamodule.py => test_cifar100.py} | 0 ...magenet_datamodule.py => test_imagenet.py} | 0 ...test_mnist_datamodule.py => test_mnist.py} | 6 +++++- ...et_datamodule.py => test_tiny_imagenet.py} | 0 ...n_datamodule.py => test_uci_regression.py} | 0 ...stochastic_model.py => test_stochastic.py} | 0 tests/models/wrappers/test_swa.py | 19 ++++++++++++++++--- .../classification/deep_ensembles.py | 3 +-- 9 files changed, 22 insertions(+), 6 deletions(-) rename tests/datamodules/classification/{test_cifar10_datamodule.py => test_cifar10.py} (100%) rename tests/datamodules/classification/{test_cifar100_datamodule.py => test_cifar100.py} (100%) rename tests/datamodules/classification/{test_imagenet_datamodule.py => test_imagenet.py} (100%) rename tests/datamodules/classification/{test_mnist_datamodule.py => test_mnist.py} (91%) rename tests/datamodules/classification/{test_tiny_imagenet_datamodule.py => test_tiny_imagenet.py} (100%) rename tests/datamodules/classification/{test_uci_regression_datamodule.py => test_uci_regression.py} (100%) rename tests/models/wrappers/{test_stochastic_model.py => test_stochastic.py} (100%) diff --git a/tests/datamodules/classification/test_cifar10_datamodule.py b/tests/datamodules/classification/test_cifar10.py similarity index 100% rename from tests/datamodules/classification/test_cifar10_datamodule.py rename to tests/datamodules/classification/test_cifar10.py diff --git a/tests/datamodules/classification/test_cifar100_datamodule.py b/tests/datamodules/classification/test_cifar100.py similarity index 100% rename from tests/datamodules/classification/test_cifar100_datamodule.py rename to tests/datamodules/classification/test_cifar100.py diff --git a/tests/datamodules/classification/test_imagenet_datamodule.py b/tests/datamodules/classification/test_imagenet.py similarity index 100% rename from tests/datamodules/classification/test_imagenet_datamodule.py rename to tests/datamodules/classification/test_imagenet.py diff --git a/tests/datamodules/classification/test_mnist_datamodule.py b/tests/datamodules/classification/test_mnist.py similarity index 91% rename from tests/datamodules/classification/test_mnist_datamodule.py rename to tests/datamodules/classification/test_mnist.py index d0517415..f52c9abf 100644 --- a/tests/datamodules/classification/test_mnist_datamodule.py +++ b/tests/datamodules/classification/test_mnist.py @@ -12,7 +12,11 @@ class TestMNISTDataModule: def test_mnist_cutout(self): dm = MNISTDataModule( - root="./data/", batch_size=128, cutout=16, val_split=0.1 + root="./data/", + batch_size=128, + cutout=16, + val_split=0.1, + eval_ood=True, ) assert dm.dataset == MNIST diff --git a/tests/datamodules/classification/test_tiny_imagenet_datamodule.py b/tests/datamodules/classification/test_tiny_imagenet.py similarity index 100% rename from tests/datamodules/classification/test_tiny_imagenet_datamodule.py rename to tests/datamodules/classification/test_tiny_imagenet.py diff --git a/tests/datamodules/classification/test_uci_regression_datamodule.py b/tests/datamodules/classification/test_uci_regression.py similarity index 100% rename from tests/datamodules/classification/test_uci_regression_datamodule.py rename to tests/datamodules/classification/test_uci_regression.py diff --git a/tests/models/wrappers/test_stochastic_model.py b/tests/models/wrappers/test_stochastic.py similarity index 100% rename from tests/models/wrappers/test_stochastic_model.py rename to tests/models/wrappers/test_stochastic.py diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py index 4d3171ef..400f4e60 100644 --- a/tests/models/wrappers/test_swa.py +++ b/tests/models/wrappers/test_swa.py @@ -43,7 +43,9 @@ class TestSWAG: def test_training(self): dl = DataLoader(TensorDataset(torch.randn(1, 1)), batch_size=1) - swag = SWAG(dummy_model(1, 10), cycle_start=1, cycle_length=1) + swag = SWAG( + dummy_model(1, 10), cycle_start=1, cycle_length=1, max_num_models=3 + ) swag.eval() swag(torch.randn(1, 1)) @@ -59,6 +61,8 @@ def test_training(self): swag.update_model(2) swag.update_bn(dl, "cpu") swag(torch.randn(1, 1)) + swag.update_model(3) + swag.update_model(4) swag.eval() swag(torch.randn(1, 1)) @@ -67,10 +71,11 @@ def test_training(self): dummy_model(1, 10), cycle_start=1, cycle_length=1, - diag_covariance=False, + diag_covariance=True, ) swag.train() - swag.update_model(1) + swag.update_model(2) + swag.sample(1, True, False, seed=1) def test_state_dict(self): mod = dummy_model(1, 10) @@ -94,3 +99,11 @@ def test_failures(self): ValueError, match="`var_clamp` must be non-negative. " ): SWAG(nn.Module(), var_clamp=-1, cycle_start=1, cycle_length=1) + swag = SWAG( + nn.Module(), cycle_start=1, cycle_length=1, diag_covariance=True + ) + with pytest.raises( + ValueError, + match="Cannot sample full rank from diagonal covariance matrix.", + ): + swag.sample(scale=1, diag_covariance=False) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index a33480b4..e6188322 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -45,10 +45,9 @@ def __init__( optim_recipe=None, ).eval() models.append(trained_model.model) - print(models) de = deep_ensembles(models=models) - super().__init__( + super().__init__( # coverage: ignore num_classes=num_classes, model=de, loss=None, From 676f272351fd6a74a80c20a6a16b9c4943316185 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 12:00:28 +0200 Subject: [PATCH 45/57] :shirt: Improve ReadMe --- README.md | 43 ++++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 072439c1..a66c6b85 100644 --- a/README.md +++ b/README.md @@ -24,15 +24,15 @@ TorchUncertainty contains the *official implementations* of multiple papers from This package provides a multi-level API, including: -- easy-to-use ⚡️ lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. +- easy-to-use :zap: lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. - ready-to-train baselines on research datasets, such as ImageNet and CIFAR -- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (work in progress 🚧). +- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (:construction: work in progress :construction:). - **layers**, **models**, **metrics**, & **losses** available for use in your networks - scikit-learn style post-processing methods such as Temperature Scaling. Have a look at the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. -## ⚙️ Installation +## :gear: Installation TorchUncertainty requires Python 3.10 or greater. Install the desired PyTorch version in your environment. Then, install the package from PyPI: @@ -51,7 +51,6 @@ We make a quickstart available at [torch-uncertainty.github.io/quickstart](https TorchUncertainty currently supports **classification**, **probabilistic** and pointwise **regression**, **segmentation** and **pixelwise regression** (such as monocular depth estimation). It includes the official codes of the following papers: -- *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287) - *LP-BNN: Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification* - [IEEE TPAMI](https://arxiv.org/abs/2012.02818) - *Packed-Ensembles for Efficient Uncertainty Estimation* - [ICLR 2023](https://arxiv.org/abs/2210.09184) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - *MUAD: Multiple Uncertainties for Autonomous Driving, a benchmark for multiple uncertainty types and tasks* - [BMVC 2022](https://arxiv.org/abs/2203.01437) @@ -60,19 +59,16 @@ We also provide the following methods: ### Baselines -To date, the following deep learning baselines have been implemented: +To date, the following deep learning baselines have been implemented. **Click on the methods for tutorials**: -- Deep Ensembles -- MC-Dropout - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) -- BatchEnsemble -- Masksembles -- MIMO -- Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) -- Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) +- [Deep Ensembles](https://torch-uncertainty.github.io/auto_tutorials/tutorial_from_de_to_pe.html), BatchEnsemble, Masksembles, & MIMO +- [MC-Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) +- [Packed-Ensembles](https://torch-uncertainty.github.io/auto_tutorials/tutorial_from_de_to_pe.html) (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) +- [Variational Bayesian Neural Networks](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) - Checkpoint Ensembles & Snapshot Ensembles - Stochastic Weight Averaging & Stochastic Weight Averaging Gaussian - Regression with Beta Gaussian NLL Loss -- Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) +- [Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) & [Regression](https://torch-uncertainty.github.io/auto_tutorials/tutorial_der_cubic.html) ### Augmentation methods @@ -84,17 +80,18 @@ The following data augmentation methods have been implemented: To date, the following post-processing methods have been implemented: -- Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html) -- Monte Carlo Batch Normalization - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_batch_norm.html) -- A wrapper for Laplace appoximation using the [Laplace library](https://github.com/aleximmer/Laplace) +- [Temperature](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html), Vector, & Matrix scaling +- [Monte Carlo Batch Normalization](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_batch_norm.html) +- Laplace approximation using the [Laplace library](https://github.com/aleximmer/Laplace) ## Tutorials -Our documentation contains the following tutorials: +Check out our tutorials at [torch-uncertainty.github.io/auto_tutorials](https://torch-uncertainty.github.io/auto_tutorials/index.html). + +## :telescope: Projects using TorchUncertainty + +The following projects use TorchUncertainty: + +- *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287) -- [From a Standard Classifier to a Packed-Ensemble](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) -- [Training a Bayesian Neural Network in 3 minutes](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) -- [Improve Top-label Calibration with Temperature Scaling](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html) -- [Deep Evidential Regression on a Toy Example](https://torch-uncertainty.github.io/auto_tutorials/tutorial_der_cubic.html) -- [Training a LeNet with Monte-Carlo Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) -- [Training a LeNet with Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) +**If you are using TorchUncertainty in your project, please let us know, we will add your project to this list!** From 80aaaa848307e3ed27cdbb5a06b914f53aac6be2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 12:14:03 +0200 Subject: [PATCH 46/57] :bug: Fix SWAG --- tests/models/wrappers/test_swa.py | 30 +++++++++++++++++++++-- torch_uncertainty/models/wrappers/swag.py | 21 ++++++++-------- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py index 400f4e60..2e693df2 100644 --- a/tests/models/wrappers/test_swa.py +++ b/tests/models/wrappers/test_swa.py @@ -44,26 +44,52 @@ class TestSWAG: def test_training(self): dl = DataLoader(TensorDataset(torch.randn(1, 1)), batch_size=1) swag = SWAG( - dummy_model(1, 10), cycle_start=1, cycle_length=1, max_num_models=3 + dummy_model(1, 10), + cycle_start=1, + cycle_length=1, + max_num_models=3, + num_estimators=2, ) + assert swag.num_avgd_models == 0 swag.eval() swag(torch.randn(1, 1)) swag.train() swag(torch.randn(1, 1)) swag.update_model(0) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (0, 10) swag.update_bn(dl, "cpu") swag(torch.randn(1, 1)) swag.update_model(1) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (0, 10) + assert swag.num_avgd_models == 0 swag.update_bn(dl, "cpu") swag.update_model(2) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (1, 10) swag.update_bn(dl, "cpu") swag(torch.randn(1, 1)) swag.update_model(3) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (2, 10) swag.update_model(4) - + assert swag.num_avgd_models == 3 + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (3, 10) + swag.update_model(5) + assert swag.num_avgd_models == 4 + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (3, 10) swag.eval() swag(torch.randn(1, 1)) diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index 3c053135..7670979a 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -54,7 +54,6 @@ def __init__( super().__init__(model, cycle_start, cycle_length) _swag_checks(scale, max_num_models, var_clamp) - self.num_models = 0 self.num_estimators = num_estimators self.scale = scale @@ -110,12 +109,12 @@ def update_model(self, epoch: int) -> None: squared_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] new_param = param.data.detach().cpu() - mean = mean * self.num_models / ( - self.num_models + 1 - ) + new_param / (self.num_models + 1) - squared_mean = squared_mean * self.num_models / ( - self.num_models + 1 - ) + new_param**2 / (self.num_models + 1) + mean = mean * self.num_avgd_models / ( + self.num_avgd_models + 1 + ) + new_param / (self.num_avgd_models + 1) + squared_mean = squared_mean * self.num_avgd_models / ( + self.num_avgd_models + 1 + ) + new_param**2 / (self.num_avgd_models + 1) self.swag_stats[self.prfx + name_p + "_mean"] = mean self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean @@ -126,13 +125,13 @@ def update_model(self, epoch: int) -> None: ] dev = (new_param - mean).view(-1, 1).t() covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) - if self.num_models + 1 > self.max_num_models: + if self.num_avgd_models + 1 > self.max_num_models: covariance_sqrt = covariance_sqrt[1:, :] self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = ( covariance_sqrt ) - self.num_models += 1 + self.num_avgd_models += 1 self.samples = [ self.sample(self.scale, self.diag_covariance) @@ -246,10 +245,10 @@ def load_state_dict( return super().load_state_dict(state_dict, strict, assign) def compute_logdet(self, block=False): - raise NotImplementedError("Raise an issue if you need this feature") + raise NotImplementedError("Raise an issue if you need this feature.") def compute_logprob(self, vec=None, block=False, diag=False): - raise NotImplementedError("Raise an issue if you need this feature") + raise NotImplementedError("Raise an issue if you need this feature.") def _swag_checks(scale: float, max_num_models: int, var_clamp: float) -> None: From 06b990f445e425122c28c8e29ba64e214124dec8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 14:00:48 +0200 Subject: [PATCH 47/57] :sparkles: Propagate changes to the other routines & update tests --- .../mnist/configs/bayesian_lenet.yaml | 59 ++++++++++++++ .../mnist/configs/lenet_swamerged.yaml | 73 ++++++++++++++++++ .../camvid/configs/segformer.yaml | 1 - .../cityscapes/configs/deeplab.yaml | 1 - .../cityscapes/configs/segformer.yaml | 1 - .../segmentation/muad/configs/segformer.yaml | 1 - tests/_dummies/__init__.py | 6 +- tests/_dummies/baseline.py | 16 ++-- tests/_dummies/datamodule.py | 6 +- tests/_dummies/dataset.py | 2 +- tests/datamodules/test_depth.py | 6 +- ...test_depth.py => test_pixel_regression.py} | 28 +++---- tests/routines/test_regression.py | 12 ++- tests/routines/test_segmentation.py | 10 --- torch_uncertainty/baselines/regression/mlp.py | 2 +- .../baselines/segmentation/deeplab.py | 2 - .../baselines/segmentation/segformer.py | 2 - .../routines/pixel_regression.py | 76 +++++++++++-------- torch_uncertainty/routines/regression.py | 67 ++++++++++------ torch_uncertainty/routines/segmentation.py | 71 ++++++++++------- torch_uncertainty/utils/distributions.py | 11 ++- 21 files changed, 305 insertions(+), 148 deletions(-) create mode 100644 experiments/classification/mnist/configs/bayesian_lenet.yaml create mode 100644 experiments/classification/mnist/configs/lenet_swamerged.yaml rename tests/routines/{test_depth.py => test_pixel_regression.py} (73%) diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml new file mode 100644 index 00000000..0c7989ab --- /dev/null +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -0,0 +1,59 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_swamerged.yaml b/experiments/classification/mnist/configs/lenet_swamerged.yaml new file mode 100644 index 00000000..29a06780 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_swamerged.yaml @@ -0,0 +1,73 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_swa + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.SWA + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + cycle_start: 19 + cycle_length: 5 + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.SequentialLR + init_args: + milestones: + - 20 + schedulers: + - class_path: torch.optim.lr_scheduler.StepLR + init_args: + step_size: 5 + gamma: 0.1 + - class_path: torch.optim.lr_scheduler.StepLR + init_args: + step_size: 5 + gamma: 0.1 + diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index 7cbb001b..0dfac0a0 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -9,7 +9,6 @@ model: loss: CrossEntropyLoss version: std arch: 0 - num_estimators: 1 data: root: ./data batch_size: 16 diff --git a/experiments/segmentation/cityscapes/configs/deeplab.yaml b/experiments/segmentation/cityscapes/configs/deeplab.yaml index babefa1c..a76408b6 100644 --- a/experiments/segmentation/cityscapes/configs/deeplab.yaml +++ b/experiments/segmentation/cityscapes/configs/deeplab.yaml @@ -29,7 +29,6 @@ model: style: v3+ output_stride: 16 separable: false - num_estimators: 1 data: root: ./data/Cityscapes batch_size: 8 diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index 0ae0c212..edf7e8c2 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -25,7 +25,6 @@ model: loss: CrossEntropyLoss version: std arch: 0 - num_estimators: 1 data: root: ./data/Cityscapes batch_size: 8 diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml index b2abf11e..b26dbf90 100644 --- a/experiments/segmentation/muad/configs/segformer.yaml +++ b/experiments/segmentation/muad/configs/segformer.yaml @@ -10,7 +10,6 @@ model: loss: CrossEntropyLoss version: std arch: 0 - num_estimators: 1 data: root: ./data batch_size: 8 diff --git a/tests/_dummies/__init__.py b/tests/_dummies/__init__.py index d942a5ae..8dec1d03 100644 --- a/tests/_dummies/__init__.py +++ b/tests/_dummies/__init__.py @@ -1,19 +1,19 @@ # ruff: noqa: F401 from .baseline import ( DummyClassificationBaseline, - DummyDepthBaseline, + DummyPixelRegressionBaseline, DummyRegressionBaseline, DummySegmentationBaseline, ) from .datamodule import ( + DummPixelRegressionDataModule, DummyClassificationDataModule, - DummyDepthDataModule, DummyRegressionDataModule, DummySegmentationDataModule, ) from .dataset import ( + DummPixelRegressionDataset, DummyClassificationDataset, - DummyDepthDataset, DummyRegressionDataset, DummySegmentationDataset, ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index c87c71d5..e6350406 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -112,7 +112,7 @@ def __new__( probabilistic: bool, in_features: int, output_dim: int, - loss: type[nn.Module], + loss: nn.Module, baseline_type: str = "single", optim_recipe=None, dist_type: str = "normal", @@ -142,7 +142,6 @@ def __new__( output_dim=output_dim, model=model, loss=loss, - num_estimators=1, optim_recipe=optim_recipe(model), ) # baseline_type == "ensemble": @@ -156,7 +155,7 @@ def __new__( output_dim=output_dim, model=model, loss=loss, - num_estimators=2, + is_ensemble=True, optim_recipe=optim_recipe(model), format_batch_fn=RepeatTarget(2), ) @@ -168,7 +167,7 @@ def __new__( in_channels: int, num_classes: int, image_size: int, - loss: type[nn.Module], + loss: nn.Module, baseline_type: str = "single", optim_recipe=None, metric_subsampling_rate: float = 1, @@ -186,7 +185,6 @@ def __new__( model=model, loss=loss, format_batch_fn=None, - num_estimators=1, optim_recipe=optim_recipe(model), metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, @@ -202,20 +200,19 @@ def __new__( model=model, loss=loss, format_batch_fn=RepeatTarget(2), - num_estimators=2, optim_recipe=optim_recipe(model), metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, ) -class DummyDepthBaseline: +class DummyPixelRegressionBaseline: def __new__( cls, in_channels: int, output_dim: int, image_size: int, - loss: type[nn.Module], + loss: nn.Module, baseline_type: str = "single", optim_recipe=None, ) -> PixelRegressionRoutine: @@ -232,7 +229,6 @@ def __new__( model=model, loss=loss, format_batch_fn=None, - num_estimators=1, optim_recipe=optim_recipe(model), ) @@ -248,6 +244,6 @@ def __new__( model=model, loss=loss, format_batch_fn=RepeatTarget(2), - num_estimators=2, + is_ensemble=True, optim_recipe=optim_recipe(model), ) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 3de1850a..7e34a92b 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -10,8 +10,8 @@ from torch_uncertainty.datamodules.abstract import BaseDataModule from .dataset import ( + DummPixelRegressionDataset, DummyClassificationDataset, - DummyDepthDataset, DummyRegressionDataset, DummySegmentationDataset, ) @@ -249,7 +249,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummyDepthDataModule(BaseDataModule): +class DummPixelRegressionDataModule(BaseDataModule): num_channels = 3 training_task = "pixel_regression" @@ -278,7 +278,7 @@ def __init__( self.num_images = num_images self.image_size = image_size - self.dataset = DummyDepthDataset + self.dataset = DummPixelRegressionDataset self.train_transform = T.ToDtype( dtype={ diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 1ab0c66b..662e4f9f 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -214,7 +214,7 @@ def __len__(self) -> int: return len(self.data) -class DummyDepthDataset(Dataset): +class DummPixelRegressionDataset(Dataset): def __init__( self, root: Path, diff --git a/tests/datamodules/test_depth.py b/tests/datamodules/test_depth.py index bee19975..4305be05 100644 --- a/tests/datamodules/test_depth.py +++ b/tests/datamodules/test_depth.py @@ -1,6 +1,6 @@ import pytest -from tests._dummies.dataset import DummyDepthDataset +from tests._dummies.dataset import DummPixelRegressionDataset from torch_uncertainty.datamodules.depth import ( KITTIDataModule, MUADDataModule, @@ -19,7 +19,7 @@ def test_muad_main(self): assert dm.dataset == MUAD - dm.dataset = DummyDepthDataset + dm.dataset = DummPixelRegressionDataset dm.prepare_data() dm.setup() @@ -51,7 +51,7 @@ def test_nyu_main(self): assert dm.dataset == NYUv2 - dm.dataset = DummyDepthDataset + dm.dataset = DummPixelRegressionDataset dm.prepare_data() dm.setup() diff --git a/tests/routines/test_depth.py b/tests/routines/test_pixel_regression.py similarity index 73% rename from tests/routines/test_depth.py rename to tests/routines/test_pixel_regression.py index e404ca80..dfda81a4 100644 --- a/tests/routines/test_depth.py +++ b/tests/routines/test_pixel_regression.py @@ -4,8 +4,8 @@ from torch import nn from tests._dummies import ( - DummyDepthBaseline, - DummyDepthDataModule, + DummPixelRegressionDataModule, + DummyPixelRegressionBaseline, ) from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import PixelRegressionRoutine @@ -17,9 +17,11 @@ def test_one_estimator_two_classes(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummyDepthDataModule(root=root, batch_size=4, output_dim=2) + dm = DummPixelRegressionDataModule( + root=root, batch_size=4, output_dim=2 + ) - model = DummyDepthBaseline( + model = DummyPixelRegressionBaseline( in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, @@ -37,9 +39,11 @@ def test_two_estimators_one_class(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummyDepthDataModule(root=root, batch_size=4, output_dim=1) + dm = DummPixelRegressionDataModule( + root=root, batch_size=4, output_dim=1 + ) - model = DummyDepthBaseline( + model = DummyPixelRegressionBaseline( in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, @@ -54,22 +58,10 @@ def test_two_estimators_one_class(self): model(dm.get_test_set()[0][0]) def test_depth_errors(self): - with pytest.raises( - ValueError, match="num_estimators must be positive, got" - ): - PixelRegressionRoutine( - model=nn.Identity(), - output_dim=2, - loss=nn.MSELoss(), - num_estimators=0, - probabilistic=False, - ) - with pytest.raises(ValueError, match="output_dim must be positive"): PixelRegressionRoutine( model=nn.Identity(), output_dim=0, loss=nn.MSELoss(), - num_estimators=1, probabilistic=False, ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 2c7eb469..37740273 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -142,12 +142,10 @@ def test_two_estimators_two_outputs(self): model(dm.get_test_set()[0][0]) def test_regression_failures(self): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="output_dim must be positive"): RegressionRoutine( - True, 1, nn.Identity(), nn.MSELoss, num_estimators=0 - ) - - with pytest.raises(ValueError): - RegressionRoutine( - True, 0, nn.Identity(), nn.MSELoss, num_estimators=1 + probabilistic=True, + output_dim=0, + model=nn.Identity(), + loss=nn.MSELoss(), ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index cb2e41a3..172934b0 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -55,16 +55,6 @@ def test_two_estimators_two_classes(self): model(dm.get_test_set()[0][0]) def test_segmentation_errors(self): - with pytest.raises( - ValueError, match="num_estimators must be positive, got" - ): - SegmentationRoutine( - model=nn.Identity(), - num_classes=2, - loss=nn.CrossEntropyLoss(), - num_estimators=0, - ) - with pytest.raises( ValueError, match="num_classes must be at least 2, got" ): diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 02e3c658..7f5bb975 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -82,7 +82,7 @@ def __init__( output_dim=output_dim, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in self.ensemble, format_batch_fn=format_batch_fn, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/deeplab.py b/torch_uncertainty/baselines/segmentation/deeplab.py index 01575f1f..f3e3982c 100644 --- a/torch_uncertainty/baselines/segmentation/deeplab.py +++ b/torch_uncertainty/baselines/segmentation/deeplab.py @@ -28,7 +28,6 @@ def __init__( style: Literal["v3", "v3+"], output_stride: int, separable: bool, - num_estimators: int = 1, metric_subsampling_rate: float = 1e-2, log_plots: bool = False, num_calibration_bins: int = 15, @@ -52,7 +51,6 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, format_batch_fn=format_batch_fn, metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 97d98a3b..c2a46013 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -21,7 +21,6 @@ def __init__( loss: nn.Module, version: Literal["std"], arch: int, - num_estimators: int = 1, ) -> None: r"""SegFormer backbone baseline for segmentation providing support for various versions and architectures. @@ -63,7 +62,6 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, format_batch_fn=format_batch_fn, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index d729e11b..b5a398e1 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -29,6 +29,10 @@ SILog, ThresholdAccuracy, ) +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, +) from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist @@ -44,21 +48,24 @@ def __init__( output_dim: int, probabilistic: bool, loss: nn.Module, - num_estimators: int = 1, + is_ensemble: bool = False, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, num_image_plot: int = 4, ) -> None: super().__init__() - _depth_routine_checks(num_estimators, output_dim) + _depth_routine_checks(output_dim) self.model = model self.output_dim = output_dim self.one_dim_depth = output_dim == 1 self.probabilistic = probabilistic self.loss = loss - self.num_estimators = num_estimators self.num_image_plot = num_image_plot + self.is_ensemble = is_ensemble + + self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -102,6 +109,20 @@ def on_train_start(self) -> None: self.hparams, ) + def on_validation_start(self) -> None: + if self.need_epoch_update and not self.trainer.sanity_checking: + self.model.update_model(self.current_epoch) + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device + ) + + def on_test_start(self) -> None: + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device + ) + def forward(self, inputs: Tensor) -> Tensor | Distribution: """Forward pass of the routine. @@ -116,10 +137,10 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: """ pred = self.model(inputs) if self.probabilistic: - if self.num_estimators == 1: + if not self.is_ensemble: pred = squeeze_dist(pred, -1) else: - if self.num_estimators == 1: + if not self.is_ensemble: pred = pred.squeeze(-1) return pred @@ -142,26 +163,26 @@ def training_step( def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - inputs, target = batch + inputs, targets = batch if self.one_dim_depth: - target = target.unsqueeze(1) + targets = targets.unsqueeze(1) preds = self.model(inputs) if self.probabilistic: ens_dist = Independent( dist_rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators + preds, "(m b) c h w -> b m c h w", b=targets.size(0) ), 1, ) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones(preds.size(0) // targets.size(0), device=self.device) ) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: preds = rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators + preds, "(m b) c h w -> b m c h w", b=targets.size(0) ) preds = preds.mean(dim=1) @@ -169,15 +190,15 @@ def validation_step( self._plot_depth( inputs[: self.num_image_plot, ...], preds[: self.num_image_plot, ...], - target[: self.num_image_plot, ...], + targets[: self.num_image_plot, ...], stage="val", ) - valid_mask = ~torch.isnan(target) - self.val_metrics.update(preds[valid_mask], target[valid_mask]) + valid_mask = ~torch.isnan(targets) + self.val_metrics.update(preds[valid_mask], targets[valid_mask]) if self.probabilistic: self.val_prob_metrics.update( - mixture[valid_mask], target[valid_mask] + mixture[valid_mask], targets[valid_mask] ) def test_step( @@ -192,24 +213,24 @@ def test_step( "if needed." ) - inputs, target = batch + inputs, targets = batch if self.one_dim_depth: - target = target.unsqueeze(1) + targets = targets.unsqueeze(1) preds = self.model(inputs) if self.probabilistic: ens_dist = dist_rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators + preds, "(m b) c h w -> b m c h w", b=targets.size(0) ) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones(preds.size(0) // targets.size(0), device=self.device) ) mixture = MixtureSameFamily(mix, ens_dist) - self.test_metrics.nll.update(mixture, target) + self.test_metrics.nll.update(mixture, targets) preds = mixture.mean else: preds = rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators + preds, "(m b) c h w -> b m c h w", b=targets.size(0) ) preds = preds.mean(dim=1) @@ -222,15 +243,15 @@ def test_step( self._plot_depth( inputs[:num_images, ...], preds[:num_images, ...], - target[:num_images, ...], + targets[:num_images, ...], stage="test", ) - valid_mask = ~torch.isnan(target) - self.test_metrics.update(preds[valid_mask], target[valid_mask]) + valid_mask = ~torch.isnan(targets) + self.test_metrics.update(preds[valid_mask], targets[valid_mask]) if self.probabilistic: self.test_prob_metrics.update( - mixture[valid_mask], target[valid_mask] + mixture[valid_mask], targets[valid_mask] ) def on_validation_epoch_end(self) -> None: @@ -311,11 +332,6 @@ def colorize( return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0 -def _depth_routine_checks(num_estimators: int, output_dim: int) -> None: - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - +def _depth_routine_checks(output_dim: int) -> None: if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 55998518..dbbb06ed 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -15,7 +15,15 @@ from torch_uncertainty.metrics import ( DistributionNLL, ) -from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, +) +from torch_uncertainty.utils.distributions import ( + dist_rearrange, + size_dist, + squeeze_dist, +) class RegressionRoutine(LightningModule): @@ -25,7 +33,7 @@ def __init__( output_dim: int, probabilistic: bool, loss: nn.Module, - num_estimators: int = 1, + is_ensemble: bool = False, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, ) -> None: @@ -37,8 +45,8 @@ def __init__( probabilistic (bool): Whether the model is probabilistic, i.e., outputs a PyTorch distribution. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. - num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to ``1`` (single model). + is_ensemble (bool, optional): Whether the model is an ensemble. + Defaults to ``False``. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. format_batch_fn (torch.nn.Module, optional): The function to format the @@ -58,13 +66,15 @@ def __init__( `here `_. """ super().__init__() - _regression_routine_checks(num_estimators, output_dim) + _regression_routine_checks(output_dim) self.model = model self.probabilistic = probabilistic self.output_dim = output_dim self.loss = loss - self.num_estimators = num_estimators + self.is_ensemble = is_ensemble + self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -102,6 +112,20 @@ def on_train_start(self) -> None: self.hparams, ) + def on_validation_start(self) -> None: + if self.need_epoch_update and not self.trainer.sanity_checking: + self.model.update_model(self.current_epoch) + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device + ) + + def on_test_start(self) -> None: + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device + ) + def forward(self, inputs: Tensor) -> Tensor | Distribution: """Forward pass of the routine. @@ -118,12 +142,12 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: if self.probabilistic: if self.one_dim_regression: pred = squeeze_dist(pred, -1) - if self.num_estimators == 1: + if not self.is_ensemble: pred = squeeze_dist(pred, -1) else: if self.one_dim_regression: pred = pred.squeeze(-1) - if self.num_estimators == 1: + if not self.is_ensemble: pred = pred.squeeze(-1) return pred @@ -150,18 +174,18 @@ def validation_step( if self.probabilistic: ens_dist = Independent( - dist_rearrange( - preds, "(m b) c -> b m c", m=self.num_estimators - ), + dist_rearrange(preds, "(m b) c -> b m c", b=targets.size(0)), 1, ) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones( + size_dist(preds)[0] // targets.size(0), device=self.device + ) ) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = rearrange(preds, "(m b) c -> b m c", b=targets.size(0)) preds = preds.mean(dim=1) self.val_metrics.update(preds, targets) @@ -187,18 +211,18 @@ def test_step( if self.probabilistic: ens_dist = Independent( - dist_rearrange( - preds, "(m b) c -> b m c", m=self.num_estimators - ), + dist_rearrange(preds, "(m b) c -> b m c", b=targets.size(0)), 1, ) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones( + size_dist(preds)[0] // targets.size(0), device=self.device + ) ) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = rearrange(preds, "(m b) c -> b m c", b=targets.size(0)) preds = preds.mean(dim=1) self.test_metrics.update(preds, targets) @@ -225,11 +249,6 @@ def on_test_epoch_end(self) -> None: self.test_prob_metrics.reset() -def _regression_routine_checks(num_estimators: int, output_dim: int) -> None: - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - +def _regression_routine_checks(output_dim: int) -> None: if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index b372a04d..a5d89a4b 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -15,6 +15,10 @@ CategoricalNLL, MeanIntersectionOverUnion, ) +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, +) class SegmentationRoutine(LightningModule): @@ -23,7 +27,6 @@ def __init__( model: nn.Module, num_classes: int, loss: nn.Module, - num_estimators: int = 1, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, metric_subsampling_rate: float = 1e-2, @@ -36,8 +39,6 @@ def __init__( model (torch.nn.Module): Model to train. num_classes (int): Number of classes in the segmentation task. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. - num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to ``1`` (single model). optim_recipe (dict or Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. format_batch_fn (torch.nn.Module, optional): The function to format the @@ -59,7 +60,6 @@ def __init__( """ super().__init__() _segmentation_routine_checks( - num_estimators, num_classes, metric_subsampling_rate, num_calibration_bins, @@ -68,7 +68,8 @@ def __init__( self.model = model self.num_classes = num_classes self.loss = loss - self.num_estimators = num_estimators + self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -131,6 +132,20 @@ def on_train_start(self) -> None: if self.logger is not None: # coverage: ignore self.logger.log_hyperparams(self.hparams) + def on_validation_start(self) -> None: + if self.need_epoch_update and not self.trainer.sanity_checking: + self.model.update_model(self.current_epoch) + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device + ) + + def on_test_start(self) -> None: + if hasattr(self.model, "need_bn_update"): + self.model.update_bn( + self.trainer.train_dataloader, device=self.device + ) + def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: @@ -150,38 +165,42 @@ def training_step( def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - img, target = batch + img, targets = batch logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + targets = F.resize( + targets, + logits.shape[-2:], + interpolation=F.InterpolationMode.NEAREST, ) logits = rearrange( - logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + logits, "(m b) c h w -> (b h w) m c", b=targets.size(0) ) probs_per_est = logits.softmax(dim=-1) probs = probs_per_est.mean(dim=1) - target = target.flatten() - valid_mask = target != 255 - probs, target = probs[valid_mask], target[valid_mask] - self.val_seg_metrics.update(probs, target) - self.val_sbsmpl_seg_metrics.update(*self.subsample(probs, target)) + targets = targets.flatten() + valid_mask = targets != 255 + probs, targets = probs[valid_mask], targets[valid_mask] + self.val_seg_metrics.update(probs, targets) + self.val_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: - img, target = batch + img, targets = batch logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + targets = F.resize( + targets, + logits.shape[-2:], + interpolation=F.InterpolationMode.NEAREST, ) logits = rearrange( - logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + logits, "(m b) c h w -> (b h w) m c", b=targets.size(0) ) probs_per_est = logits.softmax(dim=-1) probs = probs_per_est.mean(dim=1) - target = target.flatten() - valid_mask = target != 255 - probs, target = probs[valid_mask], target[valid_mask] - self.test_seg_metrics.update(probs, target) - self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, target)) + targets = targets.flatten() + valid_mask = targets != 255 + probs, targets = probs[valid_mask], targets[valid_mask] + self.test_seg_metrics.update(probs, targets) + self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_seg_metrics.compute(), sync_dist=True) @@ -210,16 +229,10 @@ def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: def _segmentation_routine_checks( - num_estimators: int, num_classes: int, metric_subsampling_rate: float, num_calibration_bins: int, ) -> None: - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - if num_classes < 2: raise ValueError(f"num_classes must be at least 2, got {num_classes}.") diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 58f8532b..aa3d1043 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -7,8 +7,17 @@ from torch.distributions.utils import broadcast_all +def size_dist(distribution: Distribution) -> torch.Size: + if isinstance(distribution, Normal | Laplace | NormalInverseGamma): + return distribution.loc.size() + raise NotImplementedError( + f"Size of {type(distribution)} distributions is not supported." + "Raise an issue if needed." + ) + + def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: - r"""Concatenate a list of distributions into a single distribution. + """Concatenate a list of distributions into a single distribution. Args: distributions (list[Distribution]): The list of distributions. From 725bd9cc6b0cfbb56438d390cffaf3ac3a3e1c0f Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 14:55:09 +0200 Subject: [PATCH 48/57] :hammer: rename inference_size to eval_size --- experiments/depth/kitti/configs/bts.yaml | 2 +- experiments/depth/nyu/configs/bts.yaml | 2 +- .../segmentation/cityscapes/configs/deeplab.yaml | 2 +- .../segmentation/cityscapes/configs/segformer.yaml | 2 +- experiments/segmentation/muad/configs/segformer.yaml | 2 +- torch_uncertainty/datamodules/depth/base.py | 8 ++++---- torch_uncertainty/datamodules/depth/kitti.py | 6 +++--- torch_uncertainty/datamodules/depth/muad.py | 6 +++--- torch_uncertainty/datamodules/depth/nyu.py | 6 +++--- .../datamodules/segmentation/cityscapes.py | 10 +++++----- torch_uncertainty/datamodules/segmentation/muad.py | 10 +++++----- 11 files changed, 28 insertions(+), 28 deletions(-) diff --git a/experiments/depth/kitti/configs/bts.yaml b/experiments/depth/kitti/configs/bts.yaml index 89de3232..3c20e048 100644 --- a/experiments/depth/kitti/configs/bts.yaml +++ b/experiments/depth/kitti/configs/bts.yaml @@ -37,7 +37,7 @@ data: crop_size: - 352 - 704 - inference_size: + eval_size: - 352 - 1216 num_workers: 4 diff --git a/experiments/depth/nyu/configs/bts.yaml b/experiments/depth/nyu/configs/bts.yaml index 8a9d0957..48f9d7db 100644 --- a/experiments/depth/nyu/configs/bts.yaml +++ b/experiments/depth/nyu/configs/bts.yaml @@ -37,7 +37,7 @@ data: crop_size: - 416 - 544 - inference_size: + eval_size: - 480 - 640 num_workers: 8 diff --git a/experiments/segmentation/cityscapes/configs/deeplab.yaml b/experiments/segmentation/cityscapes/configs/deeplab.yaml index a76408b6..51cc2a1e 100644 --- a/experiments/segmentation/cityscapes/configs/deeplab.yaml +++ b/experiments/segmentation/cityscapes/configs/deeplab.yaml @@ -33,7 +33,7 @@ data: root: ./data/Cityscapes batch_size: 8 crop_size: 768 - inference_size: + eval_size: - 1024 - 2048 num_workers: 8 diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index edf7e8c2..145a96eb 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -29,7 +29,7 @@ data: root: ./data/Cityscapes batch_size: 8 crop_size: 1024 - inference_size: + eval_size: - 1024 - 2048 num_workers: 8 diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml index b26dbf90..a0c110e0 100644 --- a/experiments/segmentation/muad/configs/segformer.yaml +++ b/experiments/segmentation/muad/configs/segformer.yaml @@ -14,7 +14,7 @@ data: root: ./data batch_size: 8 crop_size: 1024 - inference_size: + eval_size: - 1024 - 2048 num_workers: 30 diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 2337e8e2..2b38405d 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -21,7 +21,7 @@ def __init__( min_depth: float, max_depth: float, crop_size: _size_2_t, - inference_size: _size_2_t, + eval_size: _size_2_t, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -42,7 +42,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -66,7 +66,7 @@ def __init__( self.min_depth = min_depth self.max_depth = max_depth self.crop_size = _pair(crop_size) - self.inference_size = _pair(inference_size) + self.eval_size = _pair(eval_size) self.train_transform = v2.Compose( [ @@ -91,7 +91,7 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.Resize(size=self.inference_size), + v2.Resize(size=self.eval_size), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, diff --git a/torch_uncertainty/datamodules/depth/kitti.py b/torch_uncertainty/datamodules/depth/kitti.py index 55f30296..c5035893 100644 --- a/torch_uncertainty/datamodules/depth/kitti.py +++ b/torch_uncertainty/datamodules/depth/kitti.py @@ -15,7 +15,7 @@ def __init__( min_depth: float = 1e-3, max_depth: float = 80.0, crop_size: _size_2_t = (352, 704), - inference_size: _size_2_t = (375, 1242), + eval_size: _size_2_t = (375, 1242), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -36,7 +36,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(375, 1242)``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -58,7 +58,7 @@ def __init__( min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, - inference_size=inference_size, + eval_size=eval_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py index 5ca8643b..cf4f6cde 100644 --- a/torch_uncertainty/datamodules/depth/muad.py +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -16,7 +16,7 @@ def __init__( min_depth: float, max_depth: float, crop_size: _size_2_t = 1024, - inference_size: _size_2_t = (1024, 2048), + eval_size: _size_2_t = (1024, 2048), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -36,7 +36,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -58,7 +58,7 @@ def __init__( min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, - inference_size=inference_size, + eval_size=eval_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/nyu.py b/torch_uncertainty/datamodules/depth/nyu.py index c421c044..ec925ffa 100644 --- a/torch_uncertainty/datamodules/depth/nyu.py +++ b/torch_uncertainty/datamodules/depth/nyu.py @@ -15,7 +15,7 @@ def __init__( min_depth: float = 1e-3, max_depth: float = 10.0, crop_size: _size_2_t = (416, 544), - inference_size: _size_2_t = (480, 640), + eval_size: _size_2_t = (480, 640), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -36,7 +36,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(416, 544)``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -58,7 +58,7 @@ def __init__( min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, - inference_size=inference_size, + eval_size=eval_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index c6800112..03a923a1 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -18,7 +18,7 @@ def __init__( root: str | Path, batch_size: int, crop_size: _size_2_t = 1024, - inference_size: _size_2_t = (1024, 2048), + eval_size: _size_2_t = (1024, 2048), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -35,7 +35,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and segmentation mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -84,7 +84,7 @@ def __init__( v2.Compose([ v2.ToImage(), - v2.Resize(size=inference_size, antialias=True), + v2.Resize(size=eval_size, antialias=True), v2.ToDtype({ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, @@ -109,7 +109,7 @@ def __init__( self.dataset = Cityscapes self.mode = "fine" self.crop_size = _pair(crop_size) - self.inference_size = _pair(inference_size) + self.eval_size = _pair(eval_size) self.train_transform = v2.Compose( [ @@ -138,7 +138,7 @@ def __init__( self.test_transform = v2.Compose( [ v2.ToImage(), - v2.Resize(size=self.inference_size, antialias=True), + v2.Resize(size=self.eval_size, antialias=True), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 37949b1a..43b5e44c 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -18,7 +18,7 @@ def __init__( root: str | Path, batch_size: int, crop_size: _size_2_t = 1024, - inference_size: _size_2_t = (1024, 2048), + eval_size: _size_2_t = (1024, 2048), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -35,7 +35,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and segmentation mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -84,7 +84,7 @@ def __init__( v2.Compose([ v2.ToImage(), - v2.Resize(size=inference_size, antialias=True), + v2.Resize(size=eval_size, antialias=True), v2.ToDtype({ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, @@ -108,7 +108,7 @@ def __init__( self.dataset = MUAD self.crop_size = _pair(crop_size) - self.inference_size = _pair(inference_size) + self.eval_size = _pair(eval_size) self.train_transform = v2.Compose( [ @@ -135,7 +135,7 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.Resize(size=self.inference_size, antialias=True), + v2.Resize(size=self.eval_size, antialias=True), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, From fcbfeaaff85b038f9cc5952e435dac5b57290d65 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 20:17:45 +0200 Subject: [PATCH 49/57] :bug: Fix regression routines --- tests/_dummies/baseline.py | 45 ++++++++++-- tests/_dummies/model.py | 33 +++++---- tests/routines/test_pixel_regression.py | 29 ++++++-- tests/routines/test_regression.py | 18 +++-- tests/routines/test_segmentation.py | 2 + torch_uncertainty/losses.py | 12 +++- torch_uncertainty/metrics/regression/nll.py | 16 ++++- torch_uncertainty/optim_recipes.py | 37 ++++++---- .../post_processing/calibration/scaler.py | 4 +- .../calibration/temperature_scaler.py | 4 +- torch_uncertainty/routines/classification.py | 2 +- .../routines/pixel_regression.py | 68 +++++++++++-------- torch_uncertainty/routines/regression.py | 34 +++++----- torch_uncertainty/routines/segmentation.py | 2 + torch_uncertainty/utils/distributions.py | 39 +++++++---- 15 files changed, 239 insertions(+), 106 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index e6350406..3c39376b 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -116,6 +116,8 @@ def __new__( baseline_type: str = "single", optim_recipe=None, dist_type: str = "normal", + ema: bool = False, + swa: bool = False, ) -> RegressionRoutine: if probabilistic: if dist_type == "normal": @@ -136,6 +138,11 @@ def __new__( num_classes=num_classes, last_layer=last_layer, ) + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=1) + if baseline_type == "single": return RegressionRoutine( probabilistic=probabilistic, @@ -172,12 +179,18 @@ def __new__( optim_recipe=None, metric_subsampling_rate: float = 1, log_plots: bool = False, + ema: bool = False, + swa: bool = False, ) -> SegmentationRoutine: model = dummy_segmentation_model( in_channels=in_channels, num_classes=num_classes, image_size=image_size, ) + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=1) if baseline_type == "single": return SegmentationRoutine( @@ -209,26 +222,48 @@ def __new__( class DummyPixelRegressionBaseline: def __new__( cls, + probabilistic: bool, in_channels: int, output_dim: int, image_size: int, loss: nn.Module, + dist_type: str = "normal", baseline_type: str = "single", optim_recipe=None, + ema: bool = False, + swa: bool = False, ) -> PixelRegressionRoutine: + if probabilistic: + if dist_type == "normal": + last_layer = NormalLayer(output_dim) + num_classes = output_dim * 2 + elif dist_type == "laplace": + last_layer = LaplaceLayer(output_dim) + num_classes = output_dim * 2 + else: # dist_type == "nig" + last_layer = NormalInverseGammaLayer(output_dim) + num_classes = output_dim * 4 + else: + last_layer = nn.Identity() + num_classes = output_dim + model = dummy_segmentation_model( - num_classes=output_dim, + num_classes=num_classes, in_channels=in_channels, image_size=image_size, + last_layer=last_layer, ) + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=1) if baseline_type == "single": return PixelRegressionRoutine( + probabilistic=probabilistic, output_dim=output_dim, - probabilistic=False, model=model, loss=loss, - format_batch_fn=None, optim_recipe=optim_recipe(model), ) @@ -236,11 +271,11 @@ def __new__( model = deep_ensembles( [model, copy.deepcopy(model)], task="pixel_regression", - probabilistic=False, + probabilistic=probabilistic, ) return PixelRegressionRoutine( + probabilistic=probabilistic, output_dim=output_dim, - probabilistic=False, model=model, loss=loss, format_batch_fn=RepeatTarget(2), diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 2c57aec0..c6c2d64d 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -53,6 +53,7 @@ def __init__( num_classes: int, dropout_rate: float, image_size: int, + last_layer: nn.Module, ) -> None: super().__init__() self.dropout_rate = dropout_rate @@ -63,18 +64,21 @@ def __init__( in_channels, num_classes, kernel_size=3, padding=1 ) self.dropout = nn.Dropout(p=dropout_rate) + self.last_layer = last_layer def forward(self, x: Tensor) -> Tensor: - return self.dropout( - self.conv( - torch.ones( - ( - x.shape[0], - self.in_channels, - self.image_size, - self.image_size, - ), - dtype=torch.float32, + return self.last_layer( + self.dropout( + self.conv( + torch.ones( + ( + x.shape[0], + self.in_channels, + self.image_size, + self.image_size, + ), + dtype=torch.float32, + ) ) ) ) @@ -85,7 +89,7 @@ def dummy_model( num_classes: int, dropout_rate: float = 0.0, with_feats: bool = True, - last_layer=None, + last_layer: nn.Module | None = None, ) -> _Dummy: """Dummy model for testing purposes. @@ -95,7 +99,7 @@ def dummy_model( num_estimators (int): Number of estimators in the ensemble. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. with_feats (bool, optional): Whether to include features. Defaults to True. - last_layer ([type], optional): Last layer of the model. Defaults to None. + last_layer (nn.Module, optional): Last layer of the model. Defaults to None. Returns: _Dummy: Dummy model. @@ -122,6 +126,7 @@ def dummy_segmentation_model( num_classes: int, image_size: int, dropout_rate: float = 0.0, + last_layer: nn.Module | None = None, ) -> nn.Module: """Dummy segmentation model for testing purposes. @@ -130,13 +135,17 @@ def dummy_segmentation_model( num_classes (int): Number of output classes. image_size (int): Size of the input image. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. + last_layer (nn.Module, optional): Last layer of the model. Defaults to None. Returns: nn.Module: Dummy segmentation model. """ + if last_layer is None: + last_layer = nn.Identity() return _DummySegmentation( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, image_size=image_size, + last_layer=last_layer, ) diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py index dfda81a4..ada85748 100644 --- a/tests/routines/test_pixel_regression.py +++ b/tests/routines/test_pixel_regression.py @@ -7,27 +7,46 @@ DummPixelRegressionDataModule, DummyPixelRegressionBaseline, ) +from torch_uncertainty.losses import DistributionNLLLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import PixelRegressionRoutine from torch_uncertainty.utils import TUTrainer -class TestDepth: +class TestPixelRegression: def test_one_estimator_two_classes(self): - trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=None) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummPixelRegressionDataModule( - root=root, batch_size=4, output_dim=2 + root=root, batch_size=5, output_dim=3 ) model = DummyPixelRegressionBaseline( + probabilistic=False, in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, loss=nn.MSELoss(), baseline_type="single", optim_recipe=optim_cifar10_resnet18, + ema=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=None) + model = DummyPixelRegressionBaseline( + probabilistic=True, + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=DistributionNLLLoss(), + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, ) trainer.fit(model, dm) @@ -44,12 +63,14 @@ def test_two_estimators_one_class(self): ) model = DummyPixelRegressionBaseline( + probabilistic=False, in_channels=dm.num_channels, output_dim=dm.output_dim, image_size=dm.image_size, loss=nn.MSELoss(), baseline_type="ensemble", optim_recipe=optim_cifar10_resnet18, + swa=True, ) trainer.fit(model, dm) @@ -60,8 +81,8 @@ def test_two_estimators_one_class(self): def test_depth_errors(self): with pytest.raises(ValueError, match="output_dim must be positive"): PixelRegressionRoutine( + probabilistic=False, model=nn.Identity(), output_dim=0, loss=nn.MSELoss(), - probabilistic=False, ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 37740273..7c03ab1e 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -26,6 +26,7 @@ def test_one_estimator_one_output(self): loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", + ema=True, ) trainer.fit(model, dm) @@ -33,13 +34,15 @@ def test_one_estimator_one_output(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", + swa=True, ) trainer.fit(model, dm) @@ -63,18 +66,21 @@ def test_one_estimator_two_outputs(self): dist_type="laplace", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) @@ -94,18 +100,21 @@ def test_two_estimators_one_output(self): dist_type="nig", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) @@ -128,11 +137,12 @@ def test_two_estimators_two_outputs(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 172934b0..f749db4b 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -27,6 +27,7 @@ def test_one_estimator_two_classes(self): baseline_type="single", optim_recipe=optim_cifar10_resnet18, log_plots=True, + ema=True, ) trainer.fit(model, dm) @@ -47,6 +48,7 @@ def test_two_estimators_two_classes(self): loss=nn.CrossEntropyLoss(), baseline_type="ensemble", optim_recipe=optim_cifar10_resnet18, + swa=True, ) trainer.fit(model, dm) diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index b0d6e1b8..45e56fb3 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -22,14 +22,24 @@ def __init__( super().__init__() self.reduction = reduction - def forward(self, dist: Distribution, targets: Tensor) -> Tensor: + def forward( + self, + dist: Distribution, + targets: Tensor, + padding_mask: Tensor | None = None, + ) -> Tensor: """Compute the NLL of the targets given predicted distributions. Args: dist (Distribution): The predicted distributions targets (Tensor): The target values + padding_mask (Tensor, optional): The padding mask. Defaults to None. + Sets the loss to 0 for padded values. """ loss = -dist.log_prob(targets) + if padding_mask is not None: + loss = loss.masked_fill(padding_mask, 0.0) + if self.reduction == "mean": loss = loss.mean() elif self.reduction == "sum": diff --git a/torch_uncertainty/metrics/regression/nll.py b/torch_uncertainty/metrics/regression/nll.py index 9b2f9c3a..9db4dd31 100644 --- a/torch_uncertainty/metrics/regression/nll.py +++ b/torch_uncertainty/metrics/regression/nll.py @@ -5,17 +5,27 @@ class DistributionNLL(CategoricalNLL): - def update(self, dist: distributions.Distribution, target: Tensor) -> None: + def update( + self, + dist: distributions.Distribution, + target: Tensor, + padding_mask: Tensor | None = None, + ) -> None: """Update state with the predicted distributions and the targets. Args: dist (torch.distributions.Distribution): Predicted distributions. target (Tensor): Ground truth labels. + padding_mask (Tensor, optional): The padding mask. Defaults to None. + Sets the loss to 0 for padded values. """ + nlog_prob = -dist.log_prob(target) + if padding_mask is not None: + nlog_prob = nlog_prob.masked_fill(padding_mask, 0.0) if self.reduction is None or self.reduction == "none": - self.values.append(-dist.log_prob(target)) + self.values.append(nlog_prob) else: - self.values += -dist.log_prob(target).sum() + self.values += nlog_prob.sum() self.total += target.size(0) def compute(self) -> Tensor: diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index e9edd107..b6267d31 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -8,21 +8,28 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -__all__ = [ - "optim_cifar10_resnet18", - "optim_cifar10_resnet34", - "optim_cifar10_resnet50", - "optim_cifar10_vgg16", - "optim_cifar10_wideresnet", - "optim_cifar100_resnet18", - "optim_cifar100_resnet34", - "optim_cifar100_resnet50", - "optim_cifar100_vgg16", - "optim_imagenet_resnet50", - "optim_imagenet_resnet50_a3", - "optim_tinyimagenet_resnet34", - "optim_tinyimagenet_resnet50", -] + +def optim_abnn( + model: nn.Module, + lr: float, + momentum: float = 0.9, + weight_decay: float = 1e-4, + nesterov: bool = True, +) -> dict: + """ABNN finetuning recipe.""" + optimizer = optim.SGD( + model.parameters(), + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + nesterov=nesterov, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[2, 4], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} def optim_cifar10_resnet18( diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 3aa435ec..d3400dfe 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -1,7 +1,7 @@ from typing import Literal import torch -from torch import Tensor, device, nn, optim +from torch import Tensor, nn, optim from torch.utils.data import DataLoader, Dataset from tqdm import tqdm @@ -17,7 +17,7 @@ def __init__( model: nn.Module | None = None, lr: float = 0.1, max_iter: int = 100, - device: Literal["cpu", "cuda"] | device | None = None, + device: Literal["cpu", "cuda"] | torch.device | None = None, ) -> None: """Virtual class for scaling post-processing for calibrated probabilities. diff --git a/torch_uncertainty/post_processing/calibration/temperature_scaler.py b/torch_uncertainty/post_processing/calibration/temperature_scaler.py index c6372bdb..f334cbab 100644 --- a/torch_uncertainty/post_processing/calibration/temperature_scaler.py +++ b/torch_uncertainty/post_processing/calibration/temperature_scaler.py @@ -1,7 +1,7 @@ from typing import Literal import torch -from torch import Tensor, device, nn +from torch import Tensor, nn from .scaler import Scaler @@ -13,7 +13,7 @@ def __init__( init_val: float = 1, lr: float = 0.1, max_iter: int = 100, - device: Literal["cpu", "cuda"] | device | None = None, + device: Literal["cpu", "cuda"] | torch.device | None = None, ) -> None: """Temperature scaling post-processing for calibrated probabilities. diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index a0c4844a..1ceb673d 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -305,7 +305,7 @@ def _apply_mixup( with torch.no_grad(): feats = self.model.feats_forward(batch[0]).detach() batch = self.mixup(*batch, feats) - elif self.mixup_params["dist_sim"] == "inp": + else: # self.mixup_params["dist_sim"] == "inp": batch = self.mixup(*batch, batch[0]) else: batch = self.mixup(*batch) diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index b5a398e1..206323b0 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -33,7 +33,11 @@ EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, ) -from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist +from torch_uncertainty.utils.distributions import ( + dist_rearrange, + dist_size, + dist_squeeze, +) class PixelRegressionRoutine(LightningModule): @@ -138,7 +142,7 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = self.model(inputs) if self.probabilistic: if not self.is_ensemble: - pred = squeeze_dist(pred, -1) + pred = dist_squeeze(pred, -1) else: if not self.is_ensemble: pred = pred.squeeze(-1) @@ -152,11 +156,21 @@ def training_step( target = target.unsqueeze(1) dists = self.model(inputs) + if self.probabilistic: + out_shape = dist_size(dists)[-2:] + else: + out_shape = dists.shape[-2:] target = F.resize( - target, dists.shape[-2:], interpolation=F.InterpolationMode.NEAREST + target, out_shape, interpolation=F.InterpolationMode.NEAREST ) - valid_mask = ~torch.isnan(target) - loss = self.loss(dists[valid_mask], target[valid_mask]) + padding_mask = torch.isnan(target) + if self.probabilistic: + loss = self.loss(dists, target, padding_mask) + else: + loss = self.loss(dists[padding_mask], target[padding_mask]) + + if self.need_step_update: + self.model.update_model(self.current_epoch) self.log("train_loss", loss) return loss @@ -166,24 +180,26 @@ def validation_step( inputs, targets = batch if self.one_dim_depth: targets = targets.unsqueeze(1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c h w -> (b c h w)") preds = self.model(inputs) if self.probabilistic: ens_dist = Independent( dist_rearrange( - preds, "(m b) c h w -> b m c h w", b=targets.size(0) + preds, "(m b) c h w -> (b c h w) m", b=batch_size ), - 1, + 0, ) mix = Categorical( - torch.ones(preds.size(0) // targets.size(0), device=self.device) + torch.ones( + (dist_size(preds)[0] // batch_size), device=self.device + ) ) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange( - preds, "(m b) c h w -> b m c h w", b=targets.size(0) - ) + preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) preds = preds.mean(dim=1) if batch_idx == 0: @@ -194,12 +210,10 @@ def validation_step( stage="val", ) - valid_mask = ~torch.isnan(targets) - self.val_metrics.update(preds[valid_mask], targets[valid_mask]) + padding_mask = torch.isnan(targets) + self.val_metrics.update(preds[padding_mask], targets[padding_mask]) if self.probabilistic: - self.val_prob_metrics.update( - mixture[valid_mask], targets[valid_mask] - ) + self.val_prob_metrics.update(mixture, targets, padding_mask) def test_step( self, @@ -212,26 +226,26 @@ def test_step( "Depth OOD detection not implemented yet. Raise an issue " "if needed." ) - inputs, targets = batch if self.one_dim_depth: targets = targets.unsqueeze(1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c h w -> (b c h w)") preds = self.model(inputs) if self.probabilistic: ens_dist = dist_rearrange( - preds, "(m b) c h w -> b m c h w", b=targets.size(0) + preds, "(m b) c h w -> (b c h w) m", b=batch_size ) mix = Categorical( - torch.ones(preds.size(0) // targets.size(0), device=self.device) + torch.ones( + (dist_size(preds)[0] // batch_size), device=self.device + ) ) mixture = MixtureSameFamily(mix, ens_dist) - self.test_metrics.nll.update(mixture, targets) preds = mixture.mean else: - preds = rearrange( - preds, "(m b) c h w -> b m c h w", b=targets.size(0) - ) + preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) preds = preds.mean(dim=1) if batch_idx == 0: @@ -247,12 +261,10 @@ def test_step( stage="test", ) - valid_mask = ~torch.isnan(targets) - self.test_metrics.update(preds[valid_mask], targets[valid_mask]) + padding_mask = torch.isnan(targets) + self.test_metrics.update(preds[padding_mask], targets[padding_mask]) if self.probabilistic: - self.test_prob_metrics.update( - mixture[valid_mask], targets[valid_mask] - ) + self.test_prob_metrics.update(mixture, targets, padding_mask) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_metrics.compute(), sync_dist=True) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index dbbb06ed..3778b6d8 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -6,7 +6,6 @@ from torch.distributions import ( Categorical, Distribution, - Independent, MixtureSameFamily, ) from torch.optim import Optimizer @@ -21,8 +20,8 @@ ) from torch_uncertainty.utils.distributions import ( dist_rearrange, - size_dist, - squeeze_dist, + dist_size, + dist_squeeze, ) @@ -141,9 +140,9 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = self.model(inputs) if self.probabilistic: if self.one_dim_regression: - pred = squeeze_dist(pred, -1) + pred = dist_squeeze(pred, -1) if not self.is_ensemble: - pred = squeeze_dist(pred, -1) + pred = dist_squeeze(pred, -1) else: if self.one_dim_regression: pred = pred.squeeze(-1) @@ -161,6 +160,8 @@ def training_step( targets = targets.unsqueeze(-1) loss = self.loss(dists, targets) + if self.need_step_update: + self.model.update_model(self.current_epoch) self.log("train_loss", loss) return loss @@ -170,22 +171,22 @@ def validation_step( inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c -> (b c)") preds = self.model(inputs) if self.probabilistic: - ens_dist = Independent( - dist_rearrange(preds, "(m b) c -> b m c", b=targets.size(0)), - 1, - ) + ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) mix = Categorical( torch.ones( - size_dist(preds)[0] // targets.size(0), device=self.device + dist_size(preds)[0] // batch_size, device=self.device ) ) + print(ens_dist, type(ens_dist)) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange(preds, "(m b) c -> b m c", b=targets.size(0)) + preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size) preds = preds.mean(dim=1) self.val_metrics.update(preds, targets) @@ -207,22 +208,21 @@ def test_step( inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c -> (b c)") preds = self.model(inputs) if self.probabilistic: - ens_dist = Independent( - dist_rearrange(preds, "(m b) c -> b m c", b=targets.size(0)), - 1, - ) + ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) mix = Categorical( torch.ones( - size_dist(preds)[0] // targets.size(0), device=self.device + dist_size(preds)[0] // batch_size, device=self.device ) ) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange(preds, "(m b) c -> b m c", b=targets.size(0)) + preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size) preds = preds.mean(dim=1) self.test_metrics.update(preds, targets) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index a5d89a4b..dddb9842 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -159,6 +159,8 @@ def training_step( target = target.flatten() valid_mask = target != 255 loss = self.loss(logits[valid_mask], target[valid_mask]) + if self.need_step_update: + self.model.update_model(self.current_epoch) self.log("train_loss", loss) return loss diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index aa3d1043..b5d0be87 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -3,11 +3,24 @@ import torch from einops import rearrange from torch import Tensor -from torch.distributions import Distribution, Laplace, Normal, constraints +from torch.distributions import ( + Distribution, + Laplace, + Normal, + constraints, +) from torch.distributions.utils import broadcast_all -def size_dist(distribution: Distribution) -> torch.Size: +def dist_size(distribution: Distribution) -> torch.Size: + """Get the size of the distribution. + + Args: + distribution (Distribution): The distribution. + + Returns: + torch.Size: The size of the distribution. + """ if isinstance(distribution, Normal | Laplace | NormalInverseGamma): return distribution.loc.size() raise NotImplementedError( @@ -53,14 +66,16 @@ def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: betas = torch.cat( [distribution.beta for distribution in distributions], dim=dim ) - return dist_type(loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas) + return NormalInverseGamma( + loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas + ) raise NotImplementedError( f"Concatenation of {dist_type} distributions is not supported." "Raise an issue if needed." ) -def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: +def dist_squeeze(distribution: Distribution, dim: int) -> Distribution: """Squeeze the distribution along a given dimension. Args: @@ -71,16 +86,16 @@ def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: Distribution: The squeezed distribution. """ dist_type = type(distribution) - if dist_type in (Normal, Laplace): + if isinstance(distribution, Normal | Laplace): loc = distribution.loc.squeeze(dim) scale = distribution.scale.squeeze(dim) return dist_type(loc=loc, scale=scale) - if dist_type == NormalInverseGamma: + if isinstance(distribution, NormalInverseGamma): loc = distribution.loc.squeeze(dim) lmbda = distribution.lmbda.squeeze(dim) alpha = distribution.alpha.squeeze(dim) beta = distribution.beta.squeeze(dim) - return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) + return NormalInverseGamma(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( f"Squeezing of {dist_type} distributions is not supported." "Raise an issue if needed." @@ -91,19 +106,19 @@ def dist_rearrange( distribution: Distribution, pattern: str, **axes_lengths: int ) -> Distribution: dist_type = type(distribution) - if dist_type in (Normal, Laplace): + print(dist_type) + if isinstance(distribution, Normal | Laplace): loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) scale = rearrange(distribution.scale, pattern=pattern, **axes_lengths) return dist_type(loc=loc, scale=scale) - if dist_type == NormalInverseGamma: + if isinstance(distribution, NormalInverseGamma): loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) lmbda = rearrange(distribution.lmbda, pattern=pattern, **axes_lengths) alpha = rearrange(distribution.alpha, pattern=pattern, **axes_lengths) beta = rearrange(distribution.beta, pattern=pattern, **axes_lengths) - return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) + return NormalInverseGamma(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( - f"Ensemble distribution of {dist_type} is not supported." - "Raise an issue if needed." + f"Rearrange of {dist_type} is not supported. Raise an issue if needed." ) From c9e0404846461129b8d76287658444dc2b588cbc Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 20:18:56 +0200 Subject: [PATCH 50/57] :sparkles: Add first version for ABNN --- torch_uncertainty/layers/bayesian/abnn.py | 67 +++++++ torch_uncertainty/post_processing/abnn.py | 210 ++++++++++++++++++++++ 2 files changed, 277 insertions(+) create mode 100644 torch_uncertainty/layers/bayesian/abnn.py create mode 100644 torch_uncertainty/post_processing/abnn.py diff --git a/torch_uncertainty/layers/bayesian/abnn.py b/torch_uncertainty/layers/bayesian/abnn.py new file mode 100644 index 00000000..dcf75122 --- /dev/null +++ b/torch_uncertainty/layers/bayesian/abnn.py @@ -0,0 +1,67 @@ +import torch +from torch import Tensor, nn +from torch.nn import functional as F + + +class BatchNormAdapter2d(nn.Module): + def __init__( + self, + num_features: int, + alpha: float = 0.1, + momentum: float = 0.1, + eps: float = 1e-5, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + self.weight = nn.Parameter( + torch.ones(num_features, device=device, dtype=dtype), + requires_grad=True, + ) + self.bias = nn.Parameter( + torch.zeros(num_features, device=device, dtype=dtype), + requires_grad=True, + ) + + self.register_buffer( + "running_mean", + torch.zeros(num_features, device=device, dtype=dtype), + ) + self.register_buffer( + "running_var", + torch.zeros(num_features, device=device, dtype=dtype), + ) + self.register_buffer( + "num_batches_tracked", + torch.tensor(0, dtype=torch.long, device=device), + ) + self.alpha = alpha + self.momentum = momentum + self.eps = eps + self.frozen = False + + def forward(self, x: Tensor) -> Tensor: + if self.frozen: + return F.batch_norm( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training, + self.momentum, + self.eps, + ) + out = F.batch_norm( + x, + self.running_mean, + self.running_var, + None, + None, + self.training, + self.momentum, + self.eps, + ) + return self.weight.unsqueeze(-1).unsqueeze(-1) * out * ( + torch.randn_like(x) * self.alpha + 1 + ) + self.bias.unsqueeze(-1).unsqueeze(-1) diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py new file mode 100644 index 00000000..815d06df --- /dev/null +++ b/torch_uncertainty/post_processing/abnn.py @@ -0,0 +1,210 @@ +import copy + +import torch +from torch import Tensor, nn +from torch.utils.data import DataLoader, Dataset + +from torch_uncertainty.layers.bayesian.abnn import BatchNormAdapter2d +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.optim_recipes import optim_abnn +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TUTrainer + +from .abstract import PostProcessing + + +class ABNN(PostProcessing): + def __init__( + self, + num_classes: int, + random_prior: float, + alpha: float, + num_models: int, + num_samples: int, + base_lr: float, + device: torch.device, + max_epochs: int = 5, + use_original_model: bool = True, + batch_size: int = 128, + precision: str = "32", + model: nn.Module | None = None, + ): + """ABNN post-processing. + + Args: + num_classes (int): Number of classes of the inner model. + random_prior (float): Random prior specializing estimators on + certain classes. + alpha (float): Alpha value for ABNN to control the diversity of + the predictions. + num_models (int): Number of stochastic models. + num_samples (int): Number of samples per model. + base_lr (float): Base learning rate. + device (torch.device): Device to use. + max_epochs (int, optional): Number of training epochs. Defaults + to 5. + use_original_model (bool, optional): Use original model during + evaluation. Defaults to True. + batch_size (int, optional): Batch size for the training of ABNN. + Defaults to 128. + precision (str, optional): Machine precision for training & eval. + Defaults to "32". + model (nn.Module | None, optional): Model to use. Defaults to None. + + Reference: + + """ + super().__init__(model) + _abnn_checks( + num_classes=num_classes, + random_prior=random_prior, + alpha=alpha, + max_epochs=max_epochs, + num_models=num_models, + num_samples=num_samples, + base_lr=base_lr, + batch_size=batch_size, + ) + self.num_classes = num_classes + self.alpha = alpha + self.base_lr = base_lr + self.num_models = num_models + self.num_samples = num_samples + self.total_models = num_models + int(use_original_model) + self.use_original_model = use_original_model + self.max_epochs = max_epochs + + self.batch_size = batch_size + self.precision = precision + self.device = device + + self.final_model = None + + # Build random prior + num_rp_classes = int(num_classes**0.5) + self.weights = [] + for _ in range(num_models): + weight = torch.ones([num_classes]) + weight[torch.randperm(num_classes)[:num_rp_classes]] += ( + random_prior - 1 + ) + self.weights.append(weight) + + def fit(self, dataset: Dataset) -> None: + if self.model is None: + raise ValueError("Model must be set before fitting.") + dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + + source_model = copy.deepcopy(self.model) + _replace_bn_layers(source_model, self.alpha) + + models = [copy.deepcopy(source_model) for _ in range(self.num_models)] + + baselines = [ + ClassificationRoutine( + num_classes=self.num_classes, + model=mod, + loss=nn.CrossEntropyLoss( + weight=self.weights[i].to(device=self.device) + ), + optim_recipe=optim_abnn(mod, lr=self.base_lr), + eval_ood=True, + ) + for i, mod in enumerate(models) + ] + + for baseline in baselines: + trainer = TUTrainer( + max_epochs=self.max_epochs, + devices=self.device, + enable_progress_bar=False, + precision=self.precision, + logger=None, + ) + trainer.fit(model=baseline, train_dataloaders=dl) + + final_models = ( + [copy.deepcopy(source_model) for _ in range(self.num_samples)] + if self.use_original_model + else [] + ) + for baseline in baselines: + model = copy.deepcopy(source_model) + model.load_state_dict(baseline.model.state_dict()) + final_models.extend( + [copy.deepcopy(model) for _ in range(self.num_samples)] + ) + + self.final_model = deep_ensembles(final_models) + + def forward( + self, + x: Tensor, + ) -> Tensor: + if self.final_model is not None: + return self.final_model(x) + if self.model is not None: + return self.model(x) + raise ValueError("Model must be set before calling forward.") + + +def _abnn_checks( + num_classes, + random_prior, + alpha, + max_epochs, + num_models, + num_samples, + base_lr, + batch_size, +) -> None: + if random_prior < 0: + raise ValueError( + f"random_prior must be greater than 0. Got {random_prior}." + ) + if batch_size < 1: + raise ValueError( + f"batch_size must be greater than 0. Got {batch_size}." + ) + if max_epochs < 1: + raise ValueError(f"epoch must be greater than 0. Got {max_epochs}.") + if num_models < 1: + raise ValueError( + f"num_models must be greater than 0. Got {num_models}." + ) + if num_samples < 1: + raise ValueError( + f"num_samples must be greater than 0. Got {num_samples}." + ) + if alpha < 0: + raise ValueError(f"alpha must be greater than 0. Got {alpha}.") + if base_lr < 0: + raise ValueError(f"base_lr must be greater than 0. Got {base_lr}.") + if num_classes < 1: + raise ValueError( + f"num_classes must be greater than 0. Got {num_classes}." + ) + + +def _replace_bn_layers(model: nn.Module, alpha: float) -> None: + """Recursively replace batch normalization layers with ABNN layers. + + Args: + model (nn.Module): Model to replace batch normalization layers. + alpha (float): Alpha value for ABNN. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + _replace_bn_layers(module, alpha) + if isinstance(module, nn.BatchNorm2d) and module.track_running_stats: + num_channels = module.num_features + new_module = BatchNormAdapter2d(num_channels, alpha=alpha) + new_module.running_mean = module.running_mean + new_module.running_var = module.running_var + new_module.num_batches_tracked = module.num_batches_tracked + + new_module.weight.data = module.weight.data + new_module.bias.data = module.bias.data + setattr(model, name, new_module) + else: + _replace_bn_layers(module, alpha) From 4fe4aece879deb39027ccb10f3a9fe2f8b0a6ff4 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 21:58:45 +0200 Subject: [PATCH 51/57] :white_check_mark: Improve coverage --- tests/_dummies/__init__.py | 2 +- tests/_dummies/baseline.py | 6 ++-- tests/_dummies/datamodule.py | 2 +- tests/routines/test_classification.py | 25 +++++++++++++ tests/routines/test_pixel_regression.py | 37 ++++++++++++++++---- tests/routines/test_segmentation.py | 2 +- tests/test_optim_recipes.py | 3 +- torch_uncertainty/post_processing/abnn.py | 6 ++-- torch_uncertainty/routines/classification.py | 30 ++++++++-------- torch_uncertainty/utils/hub.py | 6 ++-- 10 files changed, 85 insertions(+), 34 deletions(-) diff --git a/tests/_dummies/__init__.py b/tests/_dummies/__init__.py index 8dec1d03..4f2df70d 100644 --- a/tests/_dummies/__init__.py +++ b/tests/_dummies/__init__.py @@ -6,8 +6,8 @@ DummySegmentationBaseline, ) from .datamodule import ( - DummPixelRegressionDataModule, DummyClassificationDataModule, + DummyPixelRegressionDataModule, DummyRegressionDataModule, DummySegmentationDataModule, ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 3c39376b..615ca9da 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -80,9 +80,7 @@ def __new__( ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, - post_processing_method=TemperatureScaler() - if calibrate - else None, + post_processing=TemperatureScaler() if calibrate else None, save_in_csv=save_in_csv, ) # baseline_type == "ensemble": @@ -101,7 +99,7 @@ def __new__( ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, - post_processing_method=TemperatureScaler() if calibrate else None, + post_processing=TemperatureScaler() if calibrate else None, save_in_csv=save_in_csv, ) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 7e34a92b..9675dd19 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -249,7 +249,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummPixelRegressionDataModule(BaseDataModule): +class DummyPixelRegressionDataModule(BaseDataModule): num_channels = 3 training_task = "pixel_regression" diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 888087f4..999a9663 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -10,6 +10,7 @@ ) from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.transforms import RepeatTarget from torch_uncertainty.utils import TUTrainer @@ -391,3 +392,27 @@ def test_classification_failures(self): ClassificationRoutine( num_classes=10, model=model, loss=None, eval_grouping_loss=True ) + + with pytest.raises( + ValueError, + match="Mixup is not supported for ensembles at training time", + ): + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + mixup_params={"mixtype": "mixup"}, + format_batch_fn=RepeatTarget(2), + ) + + with pytest.raises( + ValueError, + match="Ensembles and post-processing methods cannot be used together. Raise an issue if needed.", + ): + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + is_ensemble=True, + post_processing=nn.Module(), + ) diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py index ada85748..8db1af34 100644 --- a/tests/routines/test_pixel_regression.py +++ b/tests/routines/test_pixel_regression.py @@ -1,24 +1,28 @@ from pathlib import Path import pytest +import torch from torch import nn from tests._dummies import ( - DummPixelRegressionDataModule, DummyPixelRegressionBaseline, + DummyPixelRegressionDataModule, ) from torch_uncertainty.losses import DistributionNLLLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 -from torch_uncertainty.routines import PixelRegressionRoutine +from torch_uncertainty.routines.pixel_regression import ( + PixelRegressionRoutine, + colorize, +) from torch_uncertainty.utils import TUTrainer class TestPixelRegression: def test_one_estimator_two_classes(self): - trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=None) + trainer = TUTrainer(accelerator="cpu", max_epochs=1, logger=None) root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummPixelRegressionDataModule( + dm = DummyPixelRegressionDataModule( root=root, batch_size=5, output_dim=3 ) @@ -38,7 +42,7 @@ def test_one_estimator_two_classes(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) - trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=None) + trainer = TUTrainer(accelerator="cpu", max_epochs=1, logger=None) model = DummyPixelRegressionBaseline( probabilistic=True, in_channels=dm.num_channels, @@ -47,6 +51,7 @@ def test_one_estimator_two_classes(self): loss=DistributionNLLLoss(), baseline_type="single", optim_recipe=optim_cifar10_resnet18, + swa=True, ) trainer.fit(model, dm) @@ -58,7 +63,7 @@ def test_two_estimators_one_class(self): trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummPixelRegressionDataModule( + dm = DummyPixelRegressionDataModule( root=root, batch_size=4, output_dim=1 ) @@ -70,7 +75,6 @@ def test_two_estimators_one_class(self): loss=nn.MSELoss(), baseline_type="ensemble", optim_recipe=optim_cifar10_resnet18, - swa=True, ) trainer.fit(model, dm) @@ -78,6 +82,25 @@ def test_two_estimators_one_class(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=None) + model = DummyPixelRegressionBaseline( + probabilistic=True, + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=DistributionNLLLoss(), + baseline_type="ensemble", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + colorize(torch.ones((10, 10)), 0, 1) + colorize(torch.ones((10, 10)), 0, 0) + def test_depth_errors(self): with pytest.raises(ValueError, match="output_dim must be positive"): PixelRegressionRoutine( diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index f749db4b..8f6ed490 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -36,7 +36,7 @@ def test_one_estimator_two_classes(self): model(dm.get_test_set()[0][0]) def test_two_estimators_two_classes(self): - trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer(accelerator="cpu", max_epochs=1, logger=None) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) diff --git a/tests/test_optim_recipes.py b/tests/test_optim_recipes.py index a3e49530..b6d15863 100644 --- a/tests/test_optim_recipes.py +++ b/tests/test_optim_recipes.py @@ -2,7 +2,7 @@ import pytest import torch -from torch_uncertainty.optim_recipes import FullSWALR, get_procedure +from torch_uncertainty.optim_recipes import FullSWALR, get_procedure, optim_abnn class TestFullSWALR: @@ -23,6 +23,7 @@ def test_optim_cifar10(self): get_procedure("resnet50", "cifar10", "packed")(model) get_procedure("wideresnet28x10", "cifar10", "batched")(model) get_procedure("vgg16", "cifar10", "standard")(model) + optim_abnn(model, lr=0.1) def test_optim_cifar100(self): model = torch.nn.Linear(1, 1) diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py index 815d06df..0cb00d7b 100644 --- a/torch_uncertainty/post_processing/abnn.py +++ b/torch_uncertainty/post_processing/abnn.py @@ -22,7 +22,7 @@ def __init__( num_models: int, num_samples: int, base_lr: float, - device: torch.device, + device: torch.device | str, max_epochs: int = 5, use_original_model: bool = True, batch_size: int = 128, @@ -116,10 +116,12 @@ def fit(self, dataset: Dataset) -> None: for baseline in baselines: trainer = TUTrainer( max_epochs=self.max_epochs, - devices=self.device, + accelerator=self.device, enable_progress_bar=False, precision=self.precision, logger=None, + weights_summary=None, + progress_bar_refresh_rate=0, ) trainer.fit(model=baseline, train_dataloaders=dl) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 1ceb673d..1a7df33e 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -73,7 +73,7 @@ def __init__( ood_criterion: Literal[ "msp", "logit", "energy", "entropy", "mi", "vr" ] = "msp", - post_processing_method: PostProcessing | None = None, + post_processing: PostProcessing | None = None, calibration_set: Literal["val", "test"] = "val", num_calibration_bins: int = 15, log_plots: bool = False, @@ -106,7 +106,7 @@ def __init__( - ``"entropy"``: Entropy of the mean prediction. - ``"mi"``: Mutual information of the ensemble. - ``"vr"``: Variation ratio of the ensemble. - post_processing_method (PostProcessing, optional): Post-processing method + post_processing (PostProcessing, optional): Post-processing method to train on the calibration set. No post-processing if None. Defaults to ``None``. calibration_set (str, optional): The post-hoc calibration dataset to @@ -135,7 +135,7 @@ def __init__( eval_grouping_loss=eval_grouping_loss, num_calibration_bins=num_calibration_bins, mixup_params=mixup_params, - post_processing_method=post_processing_method, + post_processing=post_processing, format_batch_fn=format_batch_fn, ) @@ -159,9 +159,9 @@ def __init__( self.optim_recipe = optim_recipe self.is_ensemble = is_ensemble - self.post_processing_method = post_processing_method - if self.post_processing_method is not None: - self.post_processing_method.set_model(self.model) + self.post_processing = post_processing + if self.post_processing is not None: + self.post_processing.set_model(self.model) self._init_metrics() self.mixup = self._init_mixup(mixup_params) @@ -209,7 +209,7 @@ def _init_metrics(self) -> None: self.val_cls_metrics = cls_metrics.clone(prefix="val/") self.test_cls_metrics = cls_metrics.clone(prefix="test/") - if self.calibration_set is not None: + if self.post_processing is not None: self.ts_cls_metrics = cls_metrics.clone(prefix="test/ts_") self.test_id_entropy = Entropy() @@ -329,14 +329,14 @@ def on_validation_start(self) -> None: ) def on_test_start(self) -> None: - if self.post_processing_method is not None: + if self.post_processing is not None: calibration_dataset = ( self.trainer.datamodule.val_dataloader().dataset if self.calibration_set == "val" else self.trainer.datamodule.test_dataloader()[0].dataset ) with torch.inference_mode(False): - self.post_processing_method.fit(calibration_dataset) + self.post_processing.fit(calibration_dataset) if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_logit_storage = [] @@ -447,8 +447,8 @@ def test_step( else: ood_scores = -confs - if self.post_processing_method is not None: - pp_logits = self.post_processing_method(inputs) + if self.post_processing is not None: + pp_logits = self.post_processing(inputs) pp_probs = F.softmax(pp_logits, dim=-1) self.ts_cls_metrics.update(pp_probs, targets) @@ -515,7 +515,7 @@ def on_test_epoch_end(self) -> None: {"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True ) - if self.post_processing_method is not None: + if self.post_processing is not None: tmp_metrics = self.ts_cls_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -553,7 +553,7 @@ def on_test_epoch_end(self) -> None: self.test_cls_metrics["sc/AURC"].plot()[0], ) - if self.post_processing_method is not None: + if self.post_processing is not None: self.logger.experiment.add_figure( "Reliabity diagram after calibration", self.ts_cls_metrics["cal/ECE"].plot()[0], @@ -607,7 +607,7 @@ def _classification_routine_checks( eval_grouping_loss: bool, num_calibration_bins: int, mixup_params: dict | None, - post_processing_method: PostProcessing | None, + post_processing: PostProcessing | None, format_batch_fn: nn.Module | None, ) -> None: if ood_criterion not in [ @@ -664,7 +664,7 @@ def _classification_routine_checks( "Mixup is not supported for ensembles at training time. Please set mixup_params to None." ) - if post_processing_method is not None and is_ensemble: + if post_processing is not None and is_ensemble: raise ValueError( "Ensembles and post-processing methods cannot be used together. Raise an issue if needed." ) diff --git a/torch_uncertainty/utils/hub.py b/torch_uncertainty/utils/hub.py index b48bc324..9d86a22a 100644 --- a/torch_uncertainty/utils/hub.py +++ b/torch_uncertainty/utils/hub.py @@ -7,7 +7,9 @@ from safetensors.torch import load_file -def load_hf(weight_id: str, version: int = 0) -> tuple[torch.Tensor, dict]: +def load_hf( + weight_id: str, version: int = 0 +) -> tuple[dict[str, torch.Tensor], dict]: """Load a model from the HuggingFace hub. Args: @@ -15,7 +17,7 @@ def load_hf(weight_id: str, version: int = 0) -> tuple[torch.Tensor, dict]: version (int): The id of the version when there are several on HF. Returns: - Tuple[Tensor, Dict]: The model weights and config. + Tuple[dict, dict]: The model weights and config. Note - License: TorchUncertainty's weights are released under the Apache 2.0 license. From 7a5758699dcc4aa8d4d227eb3337f2cfde370261 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 18 Jun 2024 22:04:03 +0200 Subject: [PATCH 52/57] :wrench: Lock plt version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a8ab9e8c..1bf14fef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "scipy", "huggingface-hub", "scikit-learn", - "matplotlib", + "matplotlib==3.5.2", "numpy<2", "opencv-python", "glest==0.0.1a0", From 0a3a5a72b28de1e0ec9282c0237361612f0b93f3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 19 Jun 2024 10:49:45 +0200 Subject: [PATCH 53/57] :bug: Minor changes --- tests/routines/test_pixel_regression.py | 14 ++++++++++++-- tests/routines/test_segmentation.py | 7 ++++++- torch_uncertainty/post_processing/abnn.py | 3 +-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py index 8db1af34..44055669 100644 --- a/tests/routines/test_pixel_regression.py +++ b/tests/routines/test_pixel_regression.py @@ -19,7 +19,12 @@ class TestPixelRegression: def test_one_estimator_two_classes(self): - trainer = TUTrainer(accelerator="cpu", max_epochs=1, logger=None) + trainer = TUTrainer( + accelerator="cpu", + max_epochs=1, + logger=None, + enable_checkpointing=False, + ) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummyPixelRegressionDataModule( @@ -42,7 +47,12 @@ def test_one_estimator_two_classes(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) - trainer = TUTrainer(accelerator="cpu", max_epochs=1, logger=None) + trainer = TUTrainer( + accelerator="cpu", + max_epochs=1, + logger=None, + enable_checkpointing=False, + ) model = DummyPixelRegressionBaseline( probabilistic=True, in_channels=dm.num_channels, diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 8f6ed490..19c323bf 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -36,7 +36,12 @@ def test_one_estimator_two_classes(self): model(dm.get_test_set()[0][0]) def test_two_estimators_two_classes(self): - trainer = TUTrainer(accelerator="cpu", max_epochs=1, logger=None) + trainer = TUTrainer( + accelerator="cpu", + max_epochs=1, + logger=None, + enable_checkpointing=False, + ) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py index 0cb00d7b..4d68469a 100644 --- a/torch_uncertainty/post_processing/abnn.py +++ b/torch_uncertainty/post_processing/abnn.py @@ -119,9 +119,8 @@ def fit(self, dataset: Dataset) -> None: accelerator=self.device, enable_progress_bar=False, precision=self.precision, + enable_checkpointing=False, logger=None, - weights_summary=None, - progress_bar_refresh_rate=0, ) trainer.fit(model=baseline, train_dataloaders=dl) From 41f2f8009fc089e0332cc4f9f3e46c0550afca57 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 19 Jun 2024 16:06:10 +0200 Subject: [PATCH 54/57] :fire: Remove webdataset --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1bf14fef..31db7c62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ ] [project.optional-dependencies] -image = ["scikit-image", "h5py", "webdataset"] +image = ["scikit-image", "h5py"] tabular = ["pandas"] dev = [ "torch_uncertainty[image]", @@ -67,7 +67,7 @@ docs = [ all = [ "torch_uncertainty[dev,docs,image,tabular]", "laplace-torch" - ] +] [project.urls] homepage = "https://torch-uncertainty.github.io/" From 17c3071c154e8ebda71aa42b607b1e97d971022e Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 21 Jun 2024 10:35:42 +0200 Subject: [PATCH 55/57] :books: Improve API Page --- docs/source/api.rst | 84 +++++++++++++++++-- torch_uncertainty/datasets/__init__.py | 1 + .../datasets/classification/__init__.py | 2 +- .../classification/imagenet/__init__.py | 1 + .../datasets/classification/imagenet/base.py | 6 +- .../classification/imagenet/imagenet_c.py | 1 - .../datasets/{classification => }/fractals.py | 0 torch_uncertainty/routines/classification.py | 2 +- .../routines/pixel_regression.py | 34 ++++++-- torch_uncertainty/routines/regression.py | 2 +- torch_uncertainty/routines/segmentation.py | 2 +- 11 files changed, 116 insertions(+), 19 deletions(-) rename torch_uncertainty/datasets/{classification => }/fractals.py (100%) diff --git a/docs/source/api.rst b/docs/source/api.rst index ba4e3ef5..4eafd415 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -22,25 +22,25 @@ Classification ClassificationRoutine -Regression -^^^^^^^^^^ +Segmentation +^^^^^^^^^^^^ .. autosummary:: :toctree: generated/ :nosignatures: :template: class.rst - RegressionRoutine + SegmentationRoutine -Segmentation -^^^^^^^^^^^^ +Regression +^^^^^^^^^^ .. autosummary:: :toctree: generated/ :nosignatures: :template: class.rst - SegmentationRoutine + RegressionRoutine Pixelwise Regression ^^^^^^^^^^^^^^^^^^^^ @@ -244,6 +244,7 @@ Post-Processing Methods :toctree: generated/ :nosignatures: :template: class.rst + MCBatchNorm LaplaceApprox @@ -308,3 +309,74 @@ Segmentation CamVidDataModule CityscapesDataModule MUADDataModule + +Datasets +-------- + +.. currentmodule:: torch_uncertainty.datasets + +Classification +^^^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets.classification + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + MNISTC + NotMNIST + CIFAR10C + CIFAR100C + CIFAR10H + CIFAR10N + CIFAR100N + ImageNetA + ImageNetC + ImageNetO + ImageNetR + TinyImageNet + TinyImageNetC + OpenImageO + +Regression +^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets.regression + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + UCIRegression + +Segmentation +^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets.segmentation + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + CamVid + Cityscapes + +Others & Cross-Categories +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + Fractals + FrostImages + KITTIDepth + MUAD + NYUv2 diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index d6a02df0..5acc7735 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset +from .fractals import Fractals from .frost import FrostImages from .kitti import KITTIDepth from .muad import MUAD diff --git a/torch_uncertainty/datasets/classification/__init__.py b/torch_uncertainty/datasets/classification/__init__.py index 07f228ea..9bce03a1 100644 --- a/torch_uncertainty/datasets/classification/__init__.py +++ b/torch_uncertainty/datasets/classification/__init__.py @@ -1,8 +1,8 @@ # ruff: noqa: F401 from .cifar import CIFAR10C, CIFAR10H, CIFAR10N, CIFAR100C, CIFAR100N -from .fractals import Fractals from .imagenet import ( ImageNetA, + ImageNetC, ImageNetO, ImageNetR, TinyImageNet, diff --git a/torch_uncertainty/datasets/classification/imagenet/__init__.py b/torch_uncertainty/datasets/classification/imagenet/__init__.py index 9abfd040..f5971ff5 100644 --- a/torch_uncertainty/datasets/classification/imagenet/__init__.py +++ b/torch_uncertainty/datasets/classification/imagenet/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .imagenet_a import ImageNetA +from .imagenet_c import ImageNetC from .imagenet_o import ImageNetO from .imagenet_r import ImageNetR from .tiny_imagenet import TinyImageNet diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index 891bfb9a..7d69d0f9 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -26,9 +26,9 @@ class ImageNetVariation(ImageFolder): downloaded, it is not downloaded again. Defaults to False. """ - url: str - filename: str - tgz_md5: str + url: str | list[str] + filename: str | list[str] + tgz_md5: str | list[str] dataset_name: str root_appendix: str diff --git a/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py b/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py index 10e3df58..c95e1188 100644 --- a/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py +++ b/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py @@ -1,7 +1,6 @@ from .base import ImageNetVariation -# todo, build or download class ImageNetC(ImageNetVariation): """The corrupted ImageNet-C dataset. diff --git a/torch_uncertainty/datasets/classification/fractals.py b/torch_uncertainty/datasets/fractals.py similarity index 100% rename from torch_uncertainty/datasets/classification/fractals.py rename to torch_uncertainty/datasets/fractals.py diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 1a7df33e..814ab28d 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -79,7 +79,7 @@ def __init__( log_plots: bool = False, save_in_csv: bool = False, ) -> None: - r"""Routine for training & testing on **classification tasks**. + r"""Routine for training & testing on **classification** tasks. Args: model (torch.nn.Module): Model to train. diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index 206323b0..7181d276 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -53,12 +53,31 @@ def __init__( probabilistic: bool, loss: nn.Module, is_ensemble: bool = False, - optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, + optim_recipe: dict | Optimizer | None = None, num_image_plot: int = 4, + log_plots: bool = False, ) -> None: + """Routine for training & testing on **pixel regression** tasks. + + Args: + model (nn.Module): Model to train. + output_dim (int): Number of outputs of the model. + probabilistic (bool): Whether the model is probabilistic, i.e., + outputs a PyTorch distribution. + loss (nn.Module): Loss function to optimize the :attr:`model`. + is_ensemble (bool, optional): Whether the model is an ensemble. + Defaults to ``False``. + optim_recipe (dict or Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. + format_batch_fn (nn.Module, optional): The function to format the + batch. Defaults to ``None``. + num_image_plot (int, optional): Number of images to plot. Defaults to ``4``. + log_plots (bool, optional): Indicates whether to log plots from + metrics. Defaults to ``False``. + """ super().__init__() - _depth_routine_checks(output_dim) + _depth_routine_checks(output_dim, num_image_plot) self.model = model self.output_dim = output_dim @@ -67,6 +86,7 @@ def __init__( self.loss = loss self.num_image_plot = num_image_plot self.is_ensemble = is_ensemble + self.log_plots = log_plots self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) @@ -202,7 +222,7 @@ def validation_step( preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) preds = preds.mean(dim=1) - if batch_idx == 0: + if batch_idx == 0 and self.log_plots: self._plot_depth( inputs[: self.num_image_plot, ...], preds[: self.num_image_plot, ...], @@ -248,7 +268,7 @@ def test_step( preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) preds = preds.mean(dim=1) - if batch_idx == 0: + if batch_idx == 0 and self.log_plots: num_images = ( self.num_image_plot if self.num_image_plot < inputs.size(0) @@ -344,6 +364,10 @@ def colorize( return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0 -def _depth_routine_checks(output_dim: int) -> None: +def _depth_routine_checks(output_dim: int, num_image_plot) -> None: if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") + if num_image_plot < 1: + raise ValueError( + f"num_image_plot must be positive, got {num_image_plot}." + ) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 3778b6d8..9078ee02 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -36,7 +36,7 @@ def __init__( optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, ) -> None: - r"""Routine for training & testing on **regression tasks**. + r"""Routine for training & testing on **regression** tasks. Args: model (torch.nn.Module): Model to train. diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index dddb9842..9504e518 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -33,7 +33,7 @@ def __init__( log_plots: bool = False, num_calibration_bins: int = 15, ) -> None: - r"""Routine for training & testing on segmentation tasks. + r"""Routine for training & testing on **segmentation** tasks. Args: model (torch.nn.Module): Model to train. From f56866abf1735f73d769376875f8586a6cf766c2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 21 Jun 2024 10:44:56 +0200 Subject: [PATCH 56/57] :white_check_mark: Slightly improve tests --- tests/_dummies/baseline.py | 2 +- tests/routines/test_segmentation.py | 2 +- torch_uncertainty/post_processing/abnn.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 615ca9da..535cd567 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -188,7 +188,7 @@ def __new__( if ema: model = EMA(model, momentum=0.99) if swa: - model = SWA(model, cycle_start=0, cycle_length=1) + model = SWA(model, cycle_start=0, cycle_length=2) if baseline_type == "single": return SegmentationRoutine( diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 19c323bf..0a3e822b 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -38,7 +38,7 @@ def test_one_estimator_two_classes(self): def test_two_estimators_two_classes(self): trainer = TUTrainer( accelerator="cpu", - max_epochs=1, + max_epochs=2, logger=None, enable_checkpointing=False, ) diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py index 4d68469a..9dd4e79d 100644 --- a/torch_uncertainty/post_processing/abnn.py +++ b/torch_uncertainty/post_processing/abnn.py @@ -121,6 +121,7 @@ def fit(self, dataset: Dataset) -> None: precision=self.precision, enable_checkpointing=False, logger=None, + enable_model_summary=False, ) trainer.fit(model=baseline, train_dataloaders=dl) From 2d84fe6eee6f523712e188b4615b833d4fff3309 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 26 Jun 2024 22:17:20 +0200 Subject: [PATCH 57/57] :ok_hand: Make review modifications before merging --- auto_tutorials_source/tutorial_der_cubic.py | 3 +- .../tutorial_from_de_to_pe.py | 21 +++--- .../tutorial_mc_batch_norm.py | 3 +- auto_tutorials_source/tutorial_mc_dropout.py | 6 +- auto_tutorials_source/tutorial_scaler.py | 1 - docs/source/api.rst | 19 ++++- .../classification/mnist/bayesian_lenet.py | 62 ---------------- .../mnist/configs/bayesian_lenet.yaml | 33 ++++++--- .../mnist/configs/lenet_swamerged.yaml | 73 ------------------- tests/_dummies/datamodule.py | 10 +-- tests/datamodules/test_abstract_datamodule.py | 18 ++--- tests/layers/test_distributions.py | 6 +- .../wrappers/test_checkpoint_ensemble.py | 4 +- tests/models/wrappers/test_ema.py | 2 +- tests/models/wrappers/test_stochastic.py | 12 +-- tests/models/wrappers/test_swa.py | 28 +++---- tests/routines/test_pixel_regression.py | 10 +++ tests/routines/test_segmentation.py | 26 +++++++ .../baselines/classification/resnet.py | 15 +++- .../baselines/classification/vgg.py | 8 +- .../baselines/classification/wideresnet.py | 8 +- torch_uncertainty/baselines/regression/mlp.py | 6 +- torch_uncertainty/datamodules/abstract.py | 6 +- .../datamodules/classification/cifar10.py | 4 +- .../datamodules/classification/cifar100.py | 4 +- .../datamodules/classification/imagenet.py | 4 +- .../datamodules/classification/mnist.py | 4 +- .../classification/tiny_imagenet.py | 4 +- torch_uncertainty/datamodules/depth/base.py | 4 +- .../datamodules/segmentation/camvid.py | 4 +- .../datamodules/segmentation/cityscapes.py | 4 +- .../datamodules/segmentation/muad.py | 4 +- .../datamodules/uci_regression.py | 4 +- .../datasets/classification/not_mnist.py | 2 +- torch_uncertainty/layers/distributions.py | 8 +- torch_uncertainty/losses.py | 3 +- .../models/wrappers/checkpoint_ensemble.py | 12 +-- .../models/wrappers/deep_ensembles.py | 10 ++- torch_uncertainty/models/wrappers/ema.py | 8 +- .../models/wrappers/mc_dropout.py | 8 +- .../models/wrappers/stochastic.py | 18 ++--- torch_uncertainty/models/wrappers/swa.py | 14 ++-- torch_uncertainty/models/wrappers/swag.py | 12 +-- torch_uncertainty/optim_recipes.py | 2 +- torch_uncertainty/routines/classification.py | 16 ++-- .../routines/pixel_regression.py | 24 +++--- torch_uncertainty/routines/regression.py | 16 ++-- torch_uncertainty/routines/segmentation.py | 16 ++-- torch_uncertainty/utils/distributions.py | 1 - torch_uncertainty/utils/hub.py | 4 +- 50 files changed, 262 insertions(+), 332 deletions(-) delete mode 100644 experiments/classification/mnist/bayesian_lenet.py delete mode 100644 experiments/classification/mnist/configs/lenet_swamerged.yaml diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index 62afd3ef..96d72375 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -29,7 +29,6 @@ We also need to define an optimizer using torch.optim and the neural network utils within torch.nn. """ -# %% import torch from lightning.pytorch import Trainer from lightning import LightningDataModule @@ -49,7 +48,7 @@ def optim_regression( model: nn.Module, learning_rate: float = 5e-4, -) -> dict: +): optimizer = optim.Adam( model.parameters(), lr=learning_rate, diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py index 75f39f22..2290a024 100644 --- a/auto_tutorials_source/tutorial_from_de_to_pe.py +++ b/auto_tutorials_source/tutorial_from_de_to_pe.py @@ -10,9 +10,10 @@ The MNIST dataset consists of 70 000 images of handwritten digits from 0 to 9. The images are grayscale and 28x28-pixel sized. The task is to classify the images into their respective digits. The dataset can be automatically downloaded using the `torchvision` library. In this notebook, we will train a model and an ensemble on this task and evaluate their performance. The performance will consist in the following metrics: -- Accuracy: the proportion of correctly classified images -- Brier score: a measure of the quality of the predicted probabilities -- Calibration error: a measure of the calibration of the predicted probabilities +- Accuracy: the proportion of correctly classified images, +- Brier score: a measure of the quality of the predicted probabilities, +- Calibration error: a measure of the calibration of the predicted probabilities, +- Negative Log-Likelihood: the value of the loss on the test set. Throughout this notebook, we abstract the training and evaluation process using [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) and [TorchUncertainty](https://torch-uncertainty.github.io/). @@ -22,11 +23,10 @@ TorchUncertainty includes datamodules that handle the data loading and preprocessing. We don't use them here for tutorial purposes. """ -# %% # 1. Download, instantiate and visualize the datasets # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The dataset is automatically downloaded using torchvision. We will then visualize a few images to get a sense of the data. +# The dataset is automatically downloaded using torchvision. We then visualize a few images to see a bit what we are working with. # Create the transforms for the images import torch @@ -139,11 +139,10 @@ def optim_recipe(model, lr_mult: float = 1.0): # %% # To train the model, we use [TorchUncertainty](https://torch-uncertainty.github.io/), a library that we have developed to ease -# the training and evaluation of models with uncertainty. You can have a look at the -# [documentation](https://torch-uncertainty.github.io/) and the [code](https://github.com/ENSTA-U2IS-AI/torch-uncertainty). +# the training and evaluation of models with uncertainty. # # **Note:** To train supervised classification models we most often use the cross-entropy loss. -# With weight-decay, this minimizing this loss amounts to finding a Maximum a posteriori (MAP) estimate of the model parameters. +# With weight-decay, minimizing this loss amounts to finding a Maximum a posteriori (MAP) estimate of the model parameters. # This means that the model is trained to predict the most likely class for each input. @@ -291,7 +290,7 @@ def optim_recipe(model, lr_mult: float = 1.0): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # In the paper [Packed-Ensembles for Efficient Uncertainty Quantification](https://arxiv.org/abs/2210.09184) -# published at the International Conference on Learning Representations (ICLR) last year, we introduced a +# published at the International Conference on Learning Representations (ICLR) in 2023, we introduced a # modification of Deep Ensembles to make it more computationally-efficient. The idea is to pack the ensemble # members into a single model, which allows us to train the ensemble in a single forward pass. # This modification is particularly useful when the ensemble size is large, as it is often the case in practice. @@ -386,7 +385,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: packed_perf = trainer.test(packed_routine, dataloaders=[test_dl, ood_dl]) # %% -# The training time should be approximately similar to the one of the single model that you trained before. However, please note that we are working with very small models, hence completely underusing the your GPU. As such, the training time is not representative of what you would observe with larger models. +# The training time should be approximately similar to the one of the single model that you trained before. However, please note that we are working with very small models, hence completely underusing your GPU. As such, the training time is not representative of what you would observe with larger models. # # You can read more on Packed-Ensembles in the [paper](https://arxiv.org/abs/2210.09184) or the [Medium](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873) post. # @@ -414,4 +413,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Grouping Loss # ^^^^^^^^^^^^^ # -# The grouping loss is a measure of uncertainty orthogonal to calibration. Have a look at [this paper](https://arxiv.org/abs/2210.16315) to learn about it. Check out the small library [GLest](https://github.com/aperezlebel/glest) to learn more about it. TorchUncertainty includes a wrapper of the library to compute the grouping loss with eval_grouping_loss parameter. +# The grouping loss is a measure of uncertainty orthogonal to calibration. Have a look at [this paper](https://arxiv.org/abs/2210.16315) to learn about it. Check out their small library [GLest](https://github.com/aperezlebel/glest). TorchUncertainty includes a wrapper of the library to compute the grouping loss with eval_grouping_loss parameter. diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 140559fc..a8bed883 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -22,7 +22,6 @@ We also need import the neural network utils within `torch.nn`. """ -# %% from pathlib import Path from lightning import Trainer @@ -44,7 +43,7 @@ trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -root = Path("") / "data" +root = Path("data") datamodule = MNISTDataModule(root, batch_size=128) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 2d5bf925..e19eee61 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -51,14 +51,12 @@ # dataloaders and transforms. We create the model using the # blueprint from torch_uncertainty.models and we wrap it into mc_dropout. # -# It is important to specify the arguments,``num_estimators`` -# and the ``dropout_rate`` -# to use Monte Carlo dropout. +# It is important to add a ``dropout_rate`` argument in your model to use Monte Carlo dropout. trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -root = Path("") / "data" +root = Path("data") datamodule = MNISTDataModule(root=root, batch_size=128) diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index ceaaa036..fdbfc469 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -25,7 +25,6 @@ If you use the classification routine, the plots will be automatically available in the tensorboard logs if you use the `log_plots` flag. """ -# %% from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.metrics import CalibrationError from torch_uncertainty.models.resnet import resnet diff --git a/docs/source/api.rst b/docs/source/api.rst index 4eafd415..ed5e07ce 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -156,19 +156,32 @@ Models Wrappers ^^^^^^^^ + + +Functions +""""""""" + .. autosummary:: :toctree: generated/ :nosignatures: - :template: class.rst deep_ensembles + mc_dropout + +Classes +""""""" + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + CheckpointEnsemble EMA + MCDropout StochasticModel SWA SWAG - MCDropout - mc_dropout Metrics ------- diff --git a/experiments/classification/mnist/bayesian_lenet.py b/experiments/classification/mnist/bayesian_lenet.py deleted file mode 100644 index 05a7c17e..00000000 --- a/experiments/classification/mnist/bayesian_lenet.py +++ /dev/null @@ -1,62 +0,0 @@ -from functools import partial -from pathlib import Path - -from torch import nn, optim - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.losses import ELBOLoss -from torch_uncertainty.models.lenet import bayesian_lenet -from torch_uncertainty.routines.classification import ClassificationSingle - - -def optim_lenet(model: nn.Module) -> dict: - """Optimization recipe for LeNet. - - Uses Adam default hyperparameters. - - Args: - model (nn.Module): LeNet model. - """ - optimizer = optim.Adam( - model.parameters(), - lr=1e-3, - ) - return {"optimizer": optimizer} - - -if __name__ == "__main__": - args = init_args(datamodule=MNISTDataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - - net_name = "bayesian-lenet-mnist" - - # datamodule - args.root = str(root / "data") - dm = MNISTDataModule(**vars(args)) - - # model - model = bayesian_lenet(dm.num_channels, dm.num_classes) - - # Here, the loss is a bit more complicated - # hyperparameters are from blitz. - loss = partial( - ELBOLoss, - inner_loss=nn.CrossEntropyLoss(), - kl_weight=1 / 50000, - num_samples=3, - ) - - baseline = ClassificationSingle( - model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=loss, - optim_recipe=optim_lenet, - **vars(args), - ) - - cli_main(baseline, dm, "logs/", net_name, args) diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 0c7989ab..55f6b3c6 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -28,20 +28,29 @@ trainer: check_finite: true model: model: - class_path: torch_uncertainty.models.lenet._LeNet + class_path: torch_uncertainty.models.StochasticModel init_args: - in_channels: 1 - num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d - activation: torch.nn.ReLU - norm: torch.nn.Identity - groups: 1 - dropout_rate: 0 - last_layer_dropout: false - layer_args: {} + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + num_samples: 16 num_classes: 10 - loss: CrossEntropyLoss + loss: + class_path: torch_uncertainty.losses.ELBOLoss + init_args: + kl_weight: 0.00002 + inner_loss: torch.nn.CrossEntropyLoss + num_samples: 3 data: root: ./data batch_size: 128 diff --git a/experiments/classification/mnist/configs/lenet_swamerged.yaml b/experiments/classification/mnist/configs/lenet_swamerged.yaml deleted file mode 100644 index 29a06780..00000000 --- a/experiments/classification/mnist/configs/lenet_swamerged.yaml +++ /dev/null @@ -1,73 +0,0 @@ -# lightning.pytorch==2.1.3 -seed_everything: false -eval_after_fit: true -trainer: - accelerator: gpu - devices: 1 - precision: 16-mixed - max_epochs: 75 - logger: - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: logs/lenet_swa - name: standard - default_hp_metric: false - callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val/cls/Acc - mode: max - save_last: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/cls/Acc - patience: 1000 - check_finite: true -model: - model: - class_path: torch_uncertainty.models.wrappers.SWA - init_args: - model: - class_path: torch_uncertainty.models.lenet._LeNet - init_args: - in_channels: 1 - num_classes: 10 - linear_layer: torch.nn.Linear - conv2d_layer: torch.nn.Conv2d - activation: torch.nn.ReLU - norm: torch.nn.Identity - groups: 1 - dropout_rate: 0 - last_layer_dropout: false - layer_args: {} - cycle_start: 19 - cycle_length: 5 - num_classes: 10 - loss: CrossEntropyLoss - is_ensemble: true -data: - root: ./data - batch_size: 128 -optimizer: - lr: 0.05 - momentum: 0.9 - weight_decay: 5e-4 - nesterov: true -lr_scheduler: - class_path: torch.optim.lr_scheduler.SequentialLR - init_args: - milestones: - - 20 - schedulers: - - class_path: torch.optim.lr_scheduler.StepLR - init_args: - step_size: 5 - gamma: 0.1 - - class_path: torch.optim.lr_scheduler.StepLR - init_args: - step_size: 5 - gamma: 0.1 - diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 9675dd19..9dc59dfd 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from torchvision import tv_tensors -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from .dataset import ( DummPixelRegressionDataset, @@ -17,7 +17,7 @@ ) -class DummyClassificationDataModule(BaseDataModule): +class DummyClassificationDataModule(TUDataModule): num_channels = 1 image_size: int = 4 training_task = "classification" @@ -104,7 +104,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummyRegressionDataModule(BaseDataModule): +class DummyRegressionDataModule(TUDataModule): in_features = 4 training_task = "regression" @@ -160,7 +160,7 @@ def test_dataloader(self) -> DataLoader | list[DataLoader]: return [self._data_loader(self.test)] -class DummySegmentationDataModule(BaseDataModule): +class DummySegmentationDataModule(TUDataModule): num_channels = 3 training_task = "segmentation" @@ -249,7 +249,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummyPixelRegressionDataModule(BaseDataModule): +class DummyPixelRegressionDataModule(TUDataModule): num_channels = 3 training_task = "pixel_regression" diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 25bb1952..0f0bd64f 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -4,17 +4,17 @@ from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.datamodules.abstract import ( - BaseDataModule, CrossValDataModule, + TUDataModule, ) -class TestBaseDataModule: - """Testing the BaseDataModule class.""" +class TestTUDataModule: + """Testing the TUDataModule class.""" def test_errors(self): - BaseDataModule.__abstractmethods__ = set() - dm = BaseDataModule("root", 128, 0.0, 4, True, True) + TUDataModule.__abstractmethods__ = set() + dm = TUDataModule("root", 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): dm.setup() dm._get_train_data() @@ -25,8 +25,8 @@ class TestCrossValDataModule: """Testing the CrossValDataModule class.""" def test_cv_main(self): - BaseDataModule.__abstractmethods__ = set() - dm = BaseDataModule("root", 128, 0.0, 4, True, True) + TUDataModule.__abstractmethods__ = set() + dm = TUDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds @@ -48,8 +48,8 @@ def test_cv_main(self): cv_dm.test_dataloader() def test_errors(self): - BaseDataModule.__abstractmethods__ = set() - dm = BaseDataModule("root", 128, 0.0, 4, True, True) + TUDataModule.__abstractmethods__ = set() + dm = TUDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py index 85a67c6f..63a52f27 100644 --- a/tests/layers/test_distributions.py +++ b/tests/layers/test_distributions.py @@ -1,16 +1,16 @@ import pytest from torch_uncertainty.layers.distributions import ( - AbstractDist, LaplaceLayer, NormalLayer, + TUDist, ) class TestDistributions: def test(self): - AbstractDist.__abstractmethods__ = set() - dist = AbstractDist(dim=1) + TUDist.__abstractmethods__ = set() + dist = TUDist(dim=1) dist.forward(None) def test_errors(self): diff --git a/tests/models/wrappers/test_checkpoint_ensemble.py b/tests/models/wrappers/test_checkpoint_ensemble.py index 239d13af..b159160c 100644 --- a/tests/models/wrappers/test_checkpoint_ensemble.py +++ b/tests/models/wrappers/test_checkpoint_ensemble.py @@ -14,13 +14,13 @@ def test_training(self): ens.train() ens(torch.randn(1, 1)) - ens.update_model(0) + ens.update_wrapper(0) ens.eval() ens(torch.randn(1, 1)) ens = CheckpointEnsemble(dummy_model(1, 10), use_final_checkpoint=False) ens.train() ens(torch.randn(1, 1)) - ens.update_model(0) + ens.update_wrapper(0) ens.eval() ens(torch.randn(1, 1)) diff --git a/tests/models/wrappers/test_ema.py b/tests/models/wrappers/test_ema.py index 52b81210..3be66e3e 100644 --- a/tests/models/wrappers/test_ema.py +++ b/tests/models/wrappers/test_ema.py @@ -14,7 +14,7 @@ def test_training(self): ema.eval() ema(torch.randn(1, 1)) ema.train() - ema.update_model(0) + ema.update_wrapper(0) def test_failures(self): with pytest.raises(ValueError, match="must be in the range"): diff --git a/tests/models/wrappers/test_stochastic.py b/tests/models/wrappers/test_stochastic.py index 76b60332..dc8a814e 100644 --- a/tests/models/wrappers/test_stochastic.py +++ b/tests/models/wrappers/test_stochastic.py @@ -48,24 +48,24 @@ def test_main(self): model = StochasticModel(DummyModelLinear(), 2) model.freeze() model(torch.randn(1, 1)) - assert model.model.layer.frozen + assert model.core_model.layer.frozen model.unfreeze() - assert not model.model.layer.frozen + assert not model.core_model.layer.frozen model.eval() model(torch.randn(1, 1)) model = StochasticModel(DummyModelConv(), 2) model.freeze() - assert model.model.layer.frozen + assert model.core_model.layer.frozen model.unfreeze() - assert not model.model.layer.frozen + assert not model.core_model.layer.frozen def test_mix(self): model = StochasticModel(DummyModelMix(), 2) model.freeze() - assert model.model.layer.frozen + assert model.core_model.layer.frozen model.unfreeze() - assert not model.model.layer.frozen + assert not model.core_model.layer.frozen state = model.sample()[0] keys = state.keys() diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py index 2e693df2..f590473b 100644 --- a/tests/models/wrappers/test_swa.py +++ b/tests/models/wrappers/test_swa.py @@ -18,11 +18,11 @@ def test_training(self): swa.train() swa(torch.randn(1, 1)) - swa.update_model(0) - swa.update_bn(dl, "cpu") + swa.update_wrapper(0) + swa.bn_update(dl, "cpu") - swa.update_model(1) - swa.update_bn(dl, "cpu") + swa.update_wrapper(1) + swa.bn_update(dl, "cpu") swa.eval() swa(torch.randn(1, 1)) @@ -56,36 +56,36 @@ def test_training(self): swag.train() swag(torch.randn(1, 1)) - swag.update_model(0) + swag.update_wrapper(0) assert swag.swag_stats[ "model.swag_stats.linear.weight_covariance_sqrt" ].shape == (0, 10) - swag.update_bn(dl, "cpu") + swag.bn_update(dl, "cpu") swag(torch.randn(1, 1)) - swag.update_model(1) + swag.update_wrapper(1) assert swag.swag_stats[ "model.swag_stats.linear.weight_covariance_sqrt" ].shape == (0, 10) assert swag.num_avgd_models == 0 - swag.update_bn(dl, "cpu") + swag.bn_update(dl, "cpu") - swag.update_model(2) + swag.update_wrapper(2) assert swag.swag_stats[ "model.swag_stats.linear.weight_covariance_sqrt" ].shape == (1, 10) - swag.update_bn(dl, "cpu") + swag.bn_update(dl, "cpu") swag(torch.randn(1, 1)) - swag.update_model(3) + swag.update_wrapper(3) assert swag.swag_stats[ "model.swag_stats.linear.weight_covariance_sqrt" ].shape == (2, 10) - swag.update_model(4) + swag.update_wrapper(4) assert swag.num_avgd_models == 3 assert swag.swag_stats[ "model.swag_stats.linear.weight_covariance_sqrt" ].shape == (3, 10) - swag.update_model(5) + swag.update_wrapper(5) assert swag.num_avgd_models == 4 assert swag.swag_stats[ "model.swag_stats.linear.weight_covariance_sqrt" @@ -100,7 +100,7 @@ def test_training(self): diag_covariance=True, ) swag.train() - swag.update_model(2) + swag.update_wrapper(2) swag.sample(1, True, False, seed=1) def test_state_dict(self): diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py index 44055669..56e2058d 100644 --- a/tests/routines/test_pixel_regression.py +++ b/tests/routines/test_pixel_regression.py @@ -119,3 +119,13 @@ def test_depth_errors(self): output_dim=0, loss=nn.MSELoss(), ) + + with pytest.raises(ValueError, match="num_image_plot must be positive"): + PixelRegressionRoutine( + probabilistic=False, + model=nn.Identity(), + output_dim=1, + loss=nn.MSELoss(), + num_image_plot=0, + log_plots=True, + ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 0a3e822b..7168e607 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -35,6 +35,32 @@ def test_one_estimator_two_classes(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer( + accelerator="cpu", + max_epochs=2, + logger=None, + enable_checkpointing=False, + ) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) + + model = DummySegmentationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + image_size=dm.image_size, + loss=nn.CrossEntropyLoss(), + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, + log_plots=True, + swa=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + def test_two_estimators_two_classes(self): trainer = TUTrainer( accelerator="cpu", diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 9cfd4bf7..d184cda0 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -14,10 +14,17 @@ from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget +ENSEMBLE_METHODS = [ + "packed", + "batched", + "lpbnn", + "masked", + "mc-dropout", + "mimo", +] + class ResNetBaseline(ClassificationRoutine): - single = ["std"] - ensemble = ["packed", "batched", "lpbnn", "masked", "mc-dropout", "mimo"] versions = { "std": resnet, "packed": packed_resnet, @@ -171,7 +178,7 @@ def __init__( if version not in self.versions: raise ValueError(f"Unknown version: {version}") - if version in self.ensemble: + if version in ENSEMBLE_METHODS: params |= { "num_estimators": num_estimators, } @@ -213,7 +220,7 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - is_ensemble=version in self.ensemble, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, mixup_params=mixup_params, eval_ood=eval_ood, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 0f3a5a59..7375a082 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -10,10 +10,10 @@ from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget +ENSEMBLE_METHODS = ["mc-dropout", "packed"] + class VGGBaseline(ClassificationRoutine): - single = ["std"] - ensemble = ["mc-dropout", "packed"] versions = { "std": vgg, "mc-dropout": vgg, @@ -134,7 +134,7 @@ def __init__( "num_estimators": num_estimators, } - if version in self.ensemble: + if version in ENSEMBLE_METHODS: params |= { "num_estimators": num_estimators, } @@ -162,7 +162,7 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - is_ensemble=version in self.ensemble, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, mixup_params=mixup_params, eval_ood=eval_ood, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index b2f5e82d..78abe960 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -15,10 +15,10 @@ ) from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget +ENSEMBLE_METHODS = ["packed", "batched", "masked", "mimo", "mc-dropout"] + class WideResNetBaseline(ClassificationRoutine): - single = ["std"] - ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] versions = { "std": [wideresnet28x10], "mc-dropout": [wideresnet28x10], @@ -142,7 +142,7 @@ def __init__( if version not in self.versions: raise ValueError(f"Unknown version: {version}") - if version in self.ensemble: + if version in ENSEMBLE_METHODS: params |= { "num_estimators": num_estimators, } @@ -184,7 +184,7 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - is_ensemble=version in self.ensemble, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, mixup_params=mixup_params, eval_ood=eval_ood, diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 7f5bb975..34cfdc21 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -13,10 +13,10 @@ ) from torch_uncertainty.transforms.batch import RepeatTarget +ENSEMBLE_METHODS = ["packed"] + class MLPBaseline(RegressionRoutine): - single = ["std"] - ensemble = ["packed"] versions = {"std": mlp, "packed": packed_mlp} def __init__( @@ -82,7 +82,7 @@ def __init__( output_dim=output_dim, model=model, loss=loss, - is_ensemble=version in self.ensemble, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index d5a9020f..16308d72 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -9,7 +9,7 @@ from torch.utils.data.sampler import SubsetRandomSampler -class BaseDataModule(ABC, LightningDataModule): +class TUDataModule(ABC, LightningDataModule): training_task: str train: Dataset val: Dataset @@ -150,13 +150,13 @@ def make_cross_val_splits( return cv_dm -class CrossValDataModule(BaseDataModule): +class CrossValDataModule(TUDataModule): def __init__( self, root: str | Path, train_idx: ArrayLike, val_idx: ArrayLike, - datamodule: BaseDataModule, + datamodule: TUDataModule, batch_size: int, val_split: float, num_workers: int, diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 65d58160..1e5eda4a 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -9,14 +9,14 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10, SVHN -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class CIFAR10DataModule(BaseDataModule): +class CIFAR10DataModule(TUDataModule): num_classes = 10 num_channels = 3 input_shape = (3, 32, 32) diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index e6267cde..373430bd 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -10,14 +10,14 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class CIFAR100DataModule(BaseDataModule): +class CIFAR100DataModule(TUDataModule): num_classes = 100 num_channels = 3 input_shape = (3, 32, 32) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d1321a3f..6d35303c 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.classification import ( ImageNetA, ImageNetO, @@ -23,7 +23,7 @@ ) -class ImageNetDataModule(BaseDataModule): +class ImageNetDataModule(TUDataModule): num_classes = 1000 num_channels = 3 test_datasets = ["r", "o", "a"] diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index a71835c2..b411f502 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -6,13 +6,13 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class MNISTDataModule(BaseDataModule): +class MNISTDataModule(TUDataModule): num_classes = 10 num_channels = 1 input_shape = (1, 28, 28) diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 10aab380..49506d48 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -9,7 +9,7 @@ from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet from torch_uncertainty.utils import ( create_train_val_split, @@ -17,7 +17,7 @@ ) -class TinyImageNetDataModule(BaseDataModule): +class TinyImageNetDataModule(TUDataModule): num_classes = 200 num_channels = 3 training_task = "classification" diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 2b38405d..34c69c89 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -7,12 +7,12 @@ from torchvision.datasets import VisionDataset from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class DepthDataModule(BaseDataModule): +class DepthDataModule(TUDataModule): def __init__( self, dataset: type[VisionDataset], diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index afa5a6f2..84f99ac7 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -4,11 +4,11 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.segmentation import CamVid -class CamVidDataModule(BaseDataModule): +class CamVidDataModule(TUDataModule): def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index 03a923a1..baee3d4b 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -6,13 +6,13 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.segmentation import Cityscapes from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class CityscapesDataModule(BaseDataModule): +class CityscapesDataModule(TUDataModule): def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 43b5e44c..9ba10ee4 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -6,13 +6,13 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import BaseDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import MUAD from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class MUADDataModule(BaseDataModule): +class MUADDataModule(TUDataModule): def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 2caea314..a5cbe8af 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -6,10 +6,10 @@ from torch_uncertainty.datasets.regression import UCIRegression -from .abstract import BaseDataModule +from .abstract import TUDataModule -class UCIDataModule(BaseDataModule): +class UCIDataModule(TUDataModule): training_task = "regression" def __init__( diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index e15ac640..9bd27f8c 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -66,7 +66,7 @@ def __init__( ) super().__init__( - self.root / ("notMNIST_" + subset), + self.root / f"notMNIST_{subset}", transform=transform, target_transform=target_transform, ) diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index 956fb537..4c7829c7 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -7,7 +7,7 @@ from torch_uncertainty.utils.distributions import NormalInverseGamma -class AbstractDist(ABC, nn.Module): +class TUDist(ABC, nn.Module): def __init__(self, dim: int) -> None: super().__init__() if dim < 1: @@ -19,7 +19,7 @@ def forward(self, x: Tensor) -> Distribution: pass -class NormalLayer(AbstractDist): +class NormalLayer(TUDist): """Normal distribution layer. Converts model outputs to Independent Normal distributions. @@ -49,7 +49,7 @@ def forward(self, x: Tensor) -> Normal: return Normal(loc, scale) -class LaplaceLayer(AbstractDist): +class LaplaceLayer(TUDist): """Laplace distribution layer. Converts model outputs to Independent Laplace distributions. @@ -79,7 +79,7 @@ def forward(self, x: Tensor) -> Laplace: return Laplace(loc, scale) -class NormalInverseGammaLayer(AbstractDist): +class NormalInverseGammaLayer(TUDist): """Normal-Inverse-Gamma distribution layer. Converts model outputs to Independent Normal-Inverse-Gamma distributions. diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index 45e56fb3..c82ab210 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -121,7 +121,8 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: for _ in range(self.num_samples): logits = self.model(inputs) aggregated_elbo += self.inner_loss(logits, targets) - aggregated_elbo += self.kl_weight * self._kl_div() + # TODO: This shouldn't be necessary + aggregated_elbo += self.kl_weight * self._kl_div().to(inputs.device) return aggregated_elbo / self.num_samples def set_model(self, model: nn.Module | None) -> None: diff --git a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py index 5b446882..f2d4d869 100644 --- a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py +++ b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py @@ -26,7 +26,7 @@ def __init__( Hugh Chen, Scott Lundberg, Su-In Lee. In ArXiv 2018. """ super().__init__() - self.model = model + self.core_model = model self.save_schedule = save_schedule self.use_final_checkpoint = use_final_checkpoint self.num_estimators = int(use_final_checkpoint) @@ -34,14 +34,14 @@ def __init__( self.num_estimators = 1 @torch.no_grad() - def update_model(self, epoch: int) -> None: + def update_wrapper(self, epoch: int) -> None: """Save the model at the given epoch if included in the schedule. Args: epoch (int): The current epoch. """ if self.save_schedule is None or epoch in self.save_schedule: - self.saved_models.append(copy.deepcopy(self.model)) + self.saved_models.append(copy.deepcopy(self.core_model)) self.num_estimators += 1 def eval_forward(self, x: torch.Tensor) -> torch.Tensor: @@ -58,16 +58,16 @@ def eval_forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: The model or ensemble output. """ if not len(self.saved_models): - return self.model.forward(x) + return self.core_model.forward(x) preds = torch.cat( [model.forward(x) for model in self.saved_models], dim=0 ) if self.use_final_checkpoint: - model_forward = self.model.forward(x) + model_forward = self.core_model.forward(x) preds = torch.cat([model_forward, preds], dim=0) return preds def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training: - return self.model.forward(x) + return self.core_model.forward(x) return self.eval_forward(x) diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 807ed0ed..a72ae7c4 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -15,7 +15,7 @@ def __init__( ) -> None: """Create a classification deep ensembles from a list of models.""" super().__init__() - self.models = nn.ModuleList(models) + self.core_models = nn.ModuleList(models) self.num_estimators = len(models) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -29,7 +29,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: where :math:`B` is the batch size, :math:`N` is the number of estimators, and :math:`C` is the number of classes. """ - return torch.cat([model.forward(x) for model in self.models], dim=0) + return torch.cat( + [model.forward(x) for model in self.core_models], dim=0 + ) class _RegDeepEnsembles(_DeepEnsembles): @@ -52,7 +54,9 @@ def forward(self, x: torch.Tensor) -> Distribution: Distribution: """ if self.probabilistic: - return cat_dist([model.forward(x) for model in self.models], dim=0) + return cat_dist( + [model.forward(x) for model in self.core_models], dim=0 + ) return super().forward(x) diff --git a/torch_uncertainty/models/wrappers/ema.py b/torch_uncertainty/models/wrappers/ema.py index 253bbd85..386fcca7 100644 --- a/torch_uncertainty/models/wrappers/ema.py +++ b/torch_uncertainty/models/wrappers/ema.py @@ -17,12 +17,12 @@ def __init__( """ super().__init__() _ema_checks(momentum) - self.model = model + self.core_model = model self.ema_model = copy.deepcopy(model) self.momentum = momentum self.remainder = 1 - momentum - def update_model(self, epoch: int | None = None) -> None: + def update_wrapper(self, epoch: int | None = None) -> None: """Update the EMA model. Args: @@ -30,7 +30,7 @@ def update_model(self, epoch: int | None = None) -> None: """ for ema_param, param in zip( self.ema_model.parameters(), - self.model.parameters(), + self.core_model.parameters(), strict=False, ): ema_param.data = ( @@ -42,7 +42,7 @@ def eval_forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor: if self.training: - return self.model.forward(x) + return self.core_model.forward(x) return self.eval_forward(x) diff --git a/torch_uncertainty/models/wrappers/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py index 77c591c9..986d23a8 100644 --- a/torch_uncertainty/models/wrappers/mc_dropout.py +++ b/torch_uncertainty/models/wrappers/mc_dropout.py @@ -36,7 +36,7 @@ def __init__( _dropout_checks(model, num_estimators) self.last_layer = last_layer self.on_batch = on_batch - self.model = model + self.core_model = model self.num_estimators = num_estimators self.filtered_modules = list( @@ -70,13 +70,13 @@ def forward( x: Tensor, ) -> Tensor: if self.training: - return self.model(x) + return self.core_model(x) if self.on_batch: x = x.repeat(self.num_estimators, 1, 1, 1) - return self.model(x) + return self.core_model(x) # Else, for loop return torch.cat( - [self.model(x) for _ in range(self.num_estimators)], dim=0 + [self.core_model(x) for _ in range(self.num_estimators)], dim=0 ) diff --git a/torch_uncertainty/models/wrappers/stochastic.py b/torch_uncertainty/models/wrappers/stochastic.py index 147b97cb..7f298a87 100644 --- a/torch_uncertainty/models/wrappers/stochastic.py +++ b/torch_uncertainty/models/wrappers/stochastic.py @@ -7,23 +7,23 @@ class StochasticModel(nn.Module): def __init__(self, model: nn.Module, num_samples: int) -> None: super().__init__() - self.model = model + self.core_model = model self.num_samples = num_samples def eval_forward(self, x: Tensor) -> Tensor: return torch.cat( - [self.model.forward(x) for _ in range(self.num_samples)], dim=0 + [self.core_model.forward(x) for _ in range(self.num_samples)], dim=0 ) def forward(self, x: Tensor) -> Tensor: if self.training: - return self.model.forward(x) + return self.core_model.forward(x) return self.eval_forward(x) def sample(self, num_samples: int = 1) -> list[dict]: sampled_models = [{}] * num_samples - for module_name in self.model._modules: - module = self.model._modules[module_name] + for module_name in self.core_model._modules: + module = self.core_model._modules[module_name] if isinstance(module, bayesian_modules): for model in sampled_models: weight, bias = module.sample() @@ -43,13 +43,13 @@ def sample(self, num_samples: int = 1) -> list[dict]: return sampled_models def freeze(self) -> None: - for module_name in self.model._modules: - module = self.model._modules[module_name] + for module_name in self.core_model._modules: + module = self.core_model._modules[module_name] if isinstance(module, bayesian_modules): module.freeze() def unfreeze(self) -> None: - for module_name in self.model._modules: - module = self.model._modules[module_name] + for module_name in self.core_model._modules: + module = self.core_model._modules[module_name] if isinstance(module, bayesian_modules): module.unfreeze() diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py index 5125b754..27fbb20e 100644 --- a/torch_uncertainty/models/wrappers/swa.py +++ b/torch_uncertainty/models/wrappers/swa.py @@ -32,7 +32,7 @@ def __init__( """ super().__init__() _swa_checks(cycle_start, cycle_length) - self.model = model + self.core_model = model self.cycle_start = cycle_start self.cycle_length = cycle_length @@ -41,18 +41,18 @@ def __init__( self.need_bn_update = False @torch.no_grad() - def update_model(self, epoch: int) -> None: + def update_wrapper(self, epoch: int) -> None: if ( epoch >= self.cycle_start and (epoch - self.cycle_start) % self.cycle_length == 0 ): if self.swa_model is None: - self.swa_model = copy.deepcopy(self.model) + self.swa_model = copy.deepcopy(self.core_model) self.num_avgd_models = torch.tensor(1) else: for swa_param, param in zip( self.swa_model.parameters(), - self.model.parameters(), + self.core_model.parameters(), strict=False, ): swa_param.data += (param.data - swa_param.data) / ( @@ -63,15 +63,15 @@ def update_model(self, epoch: int) -> None: def eval_forward(self, x: Tensor) -> Tensor: if self.swa_model is None: - return self.model.forward(x) + return self.core_model.forward(x) return self.swa_model.forward(x) def forward(self, x: Tensor) -> Tensor: if self.training: - return self.model.forward(x) + return self.core_model.forward(x) return self.eval_forward(x) - def update_bn(self, loader: DataLoader, device) -> None: + def bn_update(self, loader: DataLoader, device) -> None: if self.need_bn_update and self.swa_model is not None: torch.optim.swa_utils.update_bn( loader, self.swa_model, device=device diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py index 7670979a..fb12588c 100644 --- a/torch_uncertainty/models/wrappers/swag.py +++ b/torch_uncertainty/models/wrappers/swag.py @@ -67,13 +67,13 @@ def __init__( def eval_forward(self, x: torch.Tensor) -> torch.Tensor: if not self.fit: - return self.model.forward(x) + return self.core_model.forward(x) return torch.cat([mod.to(device=x.device)(x) for mod in self.samples]) def initialize_stats(self) -> None: """Initialize the SWAG dictionary of statistics.""" self.swag_stats = {} - for name_p, param in self.model.named_parameters(): + for name_p, param in self.core_model.named_parameters(): mean, squared_mean = ( torch.zeros_like(param, device="cpu"), torch.zeros_like(param, device="cpu"), @@ -88,7 +88,7 @@ def initialize_stats(self) -> None: ) @torch.no_grad() - def update_model(self, epoch: int) -> None: + def update_wrapper(self, epoch: int) -> None: """Update the SWAG posterior. The update is performed if the epoch is greater than the cycle start @@ -104,7 +104,7 @@ def update_model(self, epoch: int) -> None: ): return - for name_p, param in self.model.named_parameters(): + for name_p, param in self.core_model.named_parameters(): mean = self.swag_stats[self.prfx + name_p + "_mean"] squared_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] new_param = param.data.detach().cpu() @@ -140,7 +140,7 @@ def update_model(self, epoch: int) -> None: self.need_bn_update = True self.fit = True - def update_bn(self, loader: DataLoader, device) -> None: + def bn_update(self, loader: DataLoader, device) -> None: """Update the bachnorm statistics of the current SWAG samples. Args: @@ -189,7 +189,7 @@ def sample( def _fullrank_sample( self, scale: float, diagonal_covariance: bool ) -> nn.Module: - new_sample = copy.deepcopy(self.model) + new_sample = copy.deepcopy(self.core_model) for name_p, param in new_sample.named_parameters(): mean = self.swag_stats[self.prfx + name_p + "_mean"] diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index b6267d31..8d648400 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -26,7 +26,7 @@ def optim_abnn( ) scheduler = optim.lr_scheduler.MultiStepLR( optimizer, - milestones=[2, 4], + milestones=[1, 4], gamma=0.1, ) return {"optimizer": optimizer, "lr_scheduler": scheduler} diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 814ab28d..2d45976a 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -150,8 +150,8 @@ def __init__( self.save_in_csv = save_in_csv self.calibration_set = calibration_set self.binary_cls = num_classes == 1 - self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) - self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) self.num_calibration_bins = num_calibration_bins self.model = model self.loss = loss @@ -321,10 +321,10 @@ def on_train_start(self) -> None: ) def on_validation_start(self) -> None: - if self.need_epoch_update and not self.trainer.sanity_checking: - self.model.update_model(self.current_epoch) + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) @@ -343,7 +343,7 @@ def on_test_start(self) -> None: self.ood_logit_storage = [] if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) @@ -388,8 +388,8 @@ def training_step( loss = self.loss(logits, target) else: loss = self.loss(logits, target, self.current_epoch) - if self.need_step_update: - self.model.update_model(self.current_epoch) + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index 7181d276..b9762ffc 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -77,7 +77,7 @@ def __init__( metrics. Defaults to ``False``. """ super().__init__() - _depth_routine_checks(output_dim, num_image_plot) + _depth_routine_checks(output_dim, num_image_plot, log_plots) self.model = model self.output_dim = output_dim @@ -88,8 +88,8 @@ def __init__( self.is_ensemble = is_ensemble self.log_plots = log_plots - self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) - self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -134,16 +134,16 @@ def on_train_start(self) -> None: ) def on_validation_start(self) -> None: - if self.need_epoch_update and not self.trainer.sanity_checking: - self.model.update_model(self.current_epoch) + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) def on_test_start(self) -> None: if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) @@ -189,8 +189,8 @@ def training_step( else: loss = self.loss(dists[padding_mask], target[padding_mask]) - if self.need_step_update: - self.model.update_model(self.current_epoch) + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss @@ -364,10 +364,12 @@ def colorize( return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0 -def _depth_routine_checks(output_dim: int, num_image_plot) -> None: +def _depth_routine_checks( + output_dim: int, num_image_plot: int, log_plots: bool +) -> None: if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") - if num_image_plot < 1: + if num_image_plot < 1 and log_plots: raise ValueError( f"num_image_plot must be positive, got {num_image_plot}." ) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 9078ee02..2beeb435 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -72,8 +72,8 @@ def __init__( self.output_dim = output_dim self.loss = loss self.is_ensemble = is_ensemble - self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) - self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -112,16 +112,16 @@ def on_train_start(self) -> None: ) def on_validation_start(self) -> None: - if self.need_epoch_update and not self.trainer.sanity_checking: - self.model.update_model(self.current_epoch) + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) def on_test_start(self) -> None: if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) @@ -160,8 +160,8 @@ def training_step( targets = targets.unsqueeze(-1) loss = self.loss(dists, targets) - if self.need_step_update: - self.model.update_model(self.current_epoch) + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 9504e518..f3ece492 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -68,8 +68,8 @@ def __init__( self.model = model self.num_classes = num_classes self.loss = loss - self.need_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) - self.need_step_update = isinstance(model, STEP_UPDATE_MODEL) + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -133,16 +133,16 @@ def on_train_start(self) -> None: self.logger.log_hyperparams(self.hparams) def on_validation_start(self) -> None: - if self.need_epoch_update and not self.trainer.sanity_checking: - self.model.update_model(self.current_epoch) + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) def on_test_start(self) -> None: if hasattr(self.model, "need_bn_update"): - self.model.update_bn( + self.model.bn_update( self.trainer.train_dataloader, device=self.device ) @@ -159,8 +159,8 @@ def training_step( target = target.flatten() valid_mask = target != 255 loss = self.loss(logits[valid_mask], target[valid_mask]) - if self.need_step_update: - self.model.update_model(self.current_epoch) + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index b5d0be87..1bf2e669 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -106,7 +106,6 @@ def dist_rearrange( distribution: Distribution, pattern: str, **axes_lengths: int ) -> Distribution: dist_type = type(distribution) - print(dist_type) if isinstance(distribution, Normal | Laplace): loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) scale = rearrange(distribution.scale, pattern=pattern, **axes_lengths) diff --git a/torch_uncertainty/utils/hub.py b/torch_uncertainty/utils/hub.py index 9d86a22a..acb7e3f5 100644 --- a/torch_uncertainty/utils/hub.py +++ b/torch_uncertainty/utils/hub.py @@ -9,7 +9,7 @@ def load_hf( weight_id: str, version: int = 0 -) -> tuple[dict[str, torch.Tensor], dict]: +) -> tuple[dict[str, torch.Tensor], dict[str, str]]: """Load a model from the HuggingFace hub. Args: @@ -17,7 +17,7 @@ def load_hf( version (int): The id of the version when there are several on HF. Returns: - Tuple[dict, dict]: The model weights and config. + tuple[dict[str, torch.Tensor], dict[str, str]]: The model weights and config. Note - License: TorchUncertainty's weights are released under the Apache 2.0 license.