From 953cc972d006e1cbcc4708004f2e75459e79694b Mon Sep 17 00:00:00 2001
From: alafage
Date: Sun, 28 May 2023 00:28:53 +0200
Subject: [PATCH 01/23] Update Classification routine so that it is no longer
 abstract :hammer:

---
 torch_uncertainty/routines/classification.py | 25 ++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py
index 6a7ff824..586314cd 100644
--- a/torch_uncertainty/routines/classification.py
+++ b/torch_uncertainty/routines/classification.py
@@ -1,6 +1,6 @@
 # fmt: off
 from argparse import ArgumentParser, Namespace
-from typing import List, Tuple, Union
+from typing import Any, List, Tuple, Union
 
 import pytorch_lightning as pl
 import torch
@@ -48,6 +48,9 @@ class ClassificationSingle(pl.LightningModule):
     def __init__(
         self,
         num_classes: int,
+        model: nn.Module,
+        loss: nn.Module,
+        optimization_procedure: Any,
         use_entropy: bool = False,
         use_logits: bool = False,
         **kwargs,
@@ -61,6 +64,12 @@ def __init__(
         self.use_logits = use_logits
         self.use_entropy = use_entropy
 
+        # model
+        self.model = model
+        # loss
+        self.loss = loss
+        # optimization procedure
+        self.optimization_procedure = optimization_procedure
         # metrics
         cls_metrics = MetricCollection(
             {
@@ -91,12 +100,15 @@ def __init__(
         self.test_entropy_id = Entropy()
         self.test_entropy_ood = Entropy()
 
+    def configure_optimizers(self) -> Any:
+        return self.optimization_procedure(self)
+
     @property
     def criterion(self) -> nn.Module:
-        raise NotImplementedError()
+        return self.loss()
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError()
+        return self.model.forward(input)
 
     def on_train_start(self) -> None:
         # hyperparameters for performances
@@ -234,6 +246,9 @@ class ClassificationEnsemble(ClassificationSingle):
     def __init__(
         self,
         num_classes: int,
+        model: nn.Module,
+        loss: nn.Module,
+        optimization_procedure: Any,
         num_estimators: int,
         use_entropy: bool = False,
         use_logits: bool = False,
@@ -243,6 +258,9 @@ def __init__(
     ) -> None:
         super().__init__(
             num_classes=num_classes,
+            model=model,
+            loss=loss,
+            optimization_procedure=optimization_procedure,
             use_entropy=use_entropy,
             use_logits=use_logits,
             **kwargs,
@@ -311,7 +329,6 @@ def validation_step(  # type: ignore
     ) -> None:
         inputs, targets = batch
         logits = self.forward(inputs)
-        # logits = logits.reshape(self.num_estimators, -1, logits.size(-1))
         logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators)
         probs_per_est = F.softmax(logits, dim=-1)
         probs = probs_per_est.mean(dim=1)

From 2c38c67c0975698be27db431ec368f086d74b032 Mon Sep 17 00:00:00 2001
From: alafage
Date: Sun, 28 May 2023 00:31:17 +0200
Subject: [PATCH 02/23] Add a single, unified baseline for ResNet :hammer:

---
 torch_uncertainty/baselines/__init__.py      |   9 +-
 .../baselines/classification/__init__.py     |   0
 .../baselines/classification/resnet.py       | 225 ++++++++++++++++++
 3 files changed, 230 insertions(+), 4 deletions(-)
 create mode 100644 torch_uncertainty/baselines/classification/__init__.py
 create mode 100644 torch_uncertainty/baselines/classification/resnet.py

diff --git a/torch_uncertainty/baselines/__init__.py b/torch_uncertainty/baselines/__init__.py
index 732fd24a..71790374 100644
--- a/torch_uncertainty/baselines/__init__.py
+++ b/torch_uncertainty/baselines/__init__.py
@@ -1,5 +1,6 @@
 # flake8: noqa
-from .batched import BatchedResNet, BatchedWideResNet
-from .masked import MaskedResNet, MaskedWideResNet
-from .packed import PackedResNet, PackedWideResNet
-from .standard import 
ResNet, WideResNet +# from .batched import BatchedResNet, BatchedWideResNet +# from .masked import MaskedResNet, MaskedWideResNet +# from .packed import PackedResNet, PackedWideResNet +# from .standard import ResNet, WideResNet +from .classification.resnet import ResNet diff --git a/torch_uncertainty/baselines/classification/__init__.py b/torch_uncertainty/baselines/classification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py new file mode 100644 index 00000000..4541d65a --- /dev/null +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -0,0 +1,225 @@ +# fmt: off +from argparse import ArgumentParser, BooleanOptionalAction +from typing import Any, Literal, Optional + +import torch.nn as nn +from pytorch_lightning import LightningModule + +from torch_uncertainty.models.resnet import ( + batched_resnet18, + batched_resnet34, + batched_resnet50, + batched_resnet101, + batched_resnet152, + masked_resnet18, + masked_resnet34, + masked_resnet50, + masked_resnet101, + masked_resnet152, + packed_resnet18, + packed_resnet34, + packed_resnet50, + packed_resnet101, + packed_resnet152, + resnet18, + resnet34, + resnet50, + resnet101, + resnet152, +) +from torch_uncertainty.routines.classification import ( + ClassificationEnsemble, + ClassificationSingle, +) + + +# fmt: on +class ResNet: + single = ["vanilla"] + ensemble = ["packed", "batched", "masked"] + versions = { + "vanilla": [resnet18, resnet34, resnet50, resnet101, resnet152], + "packed": [ + packed_resnet18, + packed_resnet34, + packed_resnet50, + packed_resnet101, + packed_resnet152, + ], + "batched": [ + batched_resnet18, + batched_resnet34, + batched_resnet50, + batched_resnet101, + batched_resnet152, + ], + "masked": [ + masked_resnet18, + masked_resnet34, + masked_resnet50, + masked_resnet101, + masked_resnet152, + ], + } + archs = [18, 34, 50, 101, 152] + + def __new__( + cls, + num_classes: int, + in_channels: int, + loss: nn.Module, + optimization_procedure: Any, + version: Literal["vanilla", "packed", "batched", "masked"], + arch: int, + imagenet_structure: bool = True, + num_estimators: Optional[int] = None, + groups: Optional[int] = None, + scale: Optional[float] = None, + alpha: Optional[int] = None, + gamma: Optional[int] = None, + use_entropy: bool = False, + use_logits: bool = False, + use_mi: bool = False, + use_variation_ratio: bool = False, + **kwargs, + ) -> LightningModule: + params = { + "in_channels": in_channels, + "num_classes": num_classes, + "imagenet_structure": imagenet_structure, + } + # version specific parameters + if version == "vanilla": + # TODO: check parameters + params.update({"groups": groups}) + elif version == "packed": + # TODO: check parameters + params.update( + { + "num_estimators": num_estimators, + "alpha": alpha, + "gamma": gamma, + } + ) + elif version == "batched": + # TODO: check parameters + params.update({"num_estimators": num_estimators}) + elif version == "masked": + # TODO: check parameters + params.update( + { + "num_estimators": num_estimators, + "scale": scale, + "groups": groups, + } + ) + else: + raise ValueError(f"Unknown version: {version}") + + model = cls.versions[version][cls.archs.index(arch)](**params) + kwargs.update(params) + print(kwargs) + # routine specific parameters + if version in cls.single: + return ClassificationSingle( + model=model, + loss=loss, + optimization_procedure=optimization_procedure, + use_entropy=use_entropy, + 
use_logits=use_logits, + **kwargs, + ) + elif version in cls.ensemble: + return ClassificationEnsemble( + model=model, + loss=loss, + optimization_procedure=optimization_procedure, + use_entropy=use_entropy, + use_logits=use_logits, + use_mi=use_mi, + use_variation_ratio=use_variation_ratio, + **kwargs, + ) + + @classmethod + def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--version", + type=str, + choices=cls.versions.keys(), + required=True, + help=f"Variation of ResNet. Choose among: {cls.versions.keys()}", + ) + parser.add_argument( + "--arch", + type=int, + choices=cls.archs, + required=True, + help=f"Architecture of ResNet. Choose among: {cls.archs}", + ) + # parser.add_argument( + # "--imagenet_structure", + # action=BooleanOptionalAction, + # default=True, + # help="Use imagenet structure", + # ) + parser.add_argument( + "--num_estimators", + type=int, + default=None, + help="Number of estimators for ensemble", + ) + parser.add_argument( + "--groups", + type=int, + default=1, + help="Number of groups for vanilla or masked resnet", + ) + parser.add_argument( + "--scale", + type=float, + default=None, + help="Scale for masked resnet", + ) + parser.add_argument( + "--alpha", + type=int, + default=None, + help="Alpha for packed resnet", + ) + parser.add_argument( + "--gamma", + type=int, + default=None, + help="Gamma for packed resnet", + ) + # FIXME: should be a str to choose among the available OOD criteria + # rather than a boolean, but it is not possible since + # ClassificationSingle and ClassificationEnsemble have different OOD + # criteria. + parser.add_argument( + "--entropy", + dest="use_entropy", + action=BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--logits", + dest="use_logits", + action=BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--mutual_information", + dest="use_mi", + action=BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--variation_ratio", + dest="use_variation_ratio", + action=BooleanOptionalAction, + default=False, + ) + + return parser From 663e0b417d07372a95d5ea056931af293afd888b Mon Sep 17 00:00:00 2001 From: alafage Date: Sun, 28 May 2023 00:31:51 +0200 Subject: [PATCH 03/23] Add unique experiment file for all ResNet on CIFAR10 :sparkles: --- experiments/classification/resnet_cifar10.py | 108 +++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 experiments/classification/resnet_cifar10.py diff --git a/experiments/classification/resnet_cifar10.py b/experiments/classification/resnet_cifar10.py new file mode 100644 index 00000000..67224818 --- /dev/null +++ b/experiments/classification/resnet_cifar10.py @@ -0,0 +1,108 @@ +# fmt: off +from argparse import ArgumentParser +from pathlib import Path + +import pytorch_lightning as pl +import torch +import torch.nn as nn +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from torchinfo import summary + +import numpy as np +from torch_uncertainty.baselines import ResNet +from torch_uncertainty.datamodules import CIFAR10DataModule +from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import get_version + +# fmt: on +if __name__ == "__main__": + root = Path(__file__).parent.absolute().parents[1] + + parser = 
ArgumentParser("torch-uncertainty") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--test", type=int, default=None) + parser.add_argument("--summary", dest="summary", action="store_true") + parser.add_argument("--log_graph", dest="log_graph", action="store_true") + parser.add_argument( + "--channels_last", + action="store_true", + help="Use channels last memory format", + ) + + parser = pl.Trainer.add_argparse_args(parser) + parser = CIFAR10DataModule.add_argparse_args(parser) + parser = ResNet.add_model_specific_args(parser) + args = parser.parse_args() + + # print(args) + + if isinstance(root, str): + root = Path(root) + + if isinstance(args.seed, int): + pl.seed_everything(args.seed) + + net_name = f"{args.version}-resnet{args.arch}-cifar10" + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # model + model = ResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=get_procedure(f"resnet{args.arch}", "cifar10"), + imagenet_structure=False, + **vars(args), + ) + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + # logger + tb_logger = TensorBoardLogger( + str(root / "logs"), + name=net_name, + default_hp_metric=False, + log_graph=args.log_graph, + version=args.test, + ) + + # callbacks + save_checkpoints = ModelCheckpoint( + monitor="hp/val_acc", + mode="max", + save_last=True, + save_weights_only=True, + ) + + # Select the best model, monitor the lr and stop if NaN + callbacks = [ + save_checkpoints, + LearningRateMonitor(logging_interval="step"), + EarlyStopping(monitor="hp/val_nll", patience=np.inf, check_finite=True), + ] + # trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=callbacks, + logger=tb_logger, + deterministic=(args.seed is not None), + ) + + if args.summary: + summary(model, input_size=model.example_input_array.shape) + elif args.test is not None: + ckpt_file, _ = get_version( + root=(root / "logs" / net_name), version=args.test + ) + trainer.test(model, datamodule=dm, ckpt_path=str(ckpt_file)) + else: + # training and testing + trainer.fit(model, dm) + trainer.test(datamodule=dm, ckpt_path="best") From 5552cdacf506d636ab7c74cbf6921e391d5d0d37 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 29 May 2023 18:08:38 +0200 Subject: [PATCH 04/23] Update WideResNet baseline :hammer: + Update ResNet baseline :hammer: + Update models :hammer: --- torch_uncertainty/baselines/__init__.py | 5 +- .../baselines/batched/__init__.py | 3 - torch_uncertainty/baselines/batched/resnet.py | 161 -------------- .../baselines/batched/wideresnet.py | 128 ----------- .../baselines/classification/resnet.py | 46 ++-- .../baselines/classification/wideresnet.py | 204 ++++++++++++++++++ .../baselines/masked/__init__.py | 3 - torch_uncertainty/baselines/masked/resnet.py | 149 ------------- .../baselines/masked/wideresnet.py | 116 ---------- torch_uncertainty/baselines/mimo/__init__.py | 0 .../baselines/packed/__init__.py | 3 - torch_uncertainty/baselines/packed/resnet.py | 195 ----------------- .../baselines/packed/wideresnet.py | 110 ---------- .../baselines/standard/__init__.py | 3 - .../baselines/standard/resnet.py | 130 ----------- .../baselines/standard/wideresnet.py | 97 --------- torch_uncertainty/datamodules/cifar10.py | 1 + torch_uncertainty/datamodules/cifar100.py | 1 + torch_uncertainty/models/resnet/batched.py | 4 +- torch_uncertainty/models/resnet/masked.py | 3 +- 
torch_uncertainty/models/resnet/packed.py | 77 ++++++- .../models/wideresnet/batched.py | 3 +- torch_uncertainty/models/wideresnet/masked.py | 3 +- 23 files changed, 321 insertions(+), 1124 deletions(-) delete mode 100644 torch_uncertainty/baselines/batched/__init__.py delete mode 100644 torch_uncertainty/baselines/batched/resnet.py delete mode 100644 torch_uncertainty/baselines/batched/wideresnet.py create mode 100644 torch_uncertainty/baselines/classification/wideresnet.py delete mode 100644 torch_uncertainty/baselines/masked/__init__.py delete mode 100644 torch_uncertainty/baselines/masked/resnet.py delete mode 100644 torch_uncertainty/baselines/masked/wideresnet.py delete mode 100644 torch_uncertainty/baselines/mimo/__init__.py delete mode 100644 torch_uncertainty/baselines/packed/__init__.py delete mode 100644 torch_uncertainty/baselines/packed/resnet.py delete mode 100644 torch_uncertainty/baselines/packed/wideresnet.py delete mode 100644 torch_uncertainty/baselines/standard/__init__.py delete mode 100644 torch_uncertainty/baselines/standard/resnet.py delete mode 100644 torch_uncertainty/baselines/standard/wideresnet.py diff --git a/torch_uncertainty/baselines/__init__.py b/torch_uncertainty/baselines/__init__.py index 71790374..a90768c1 100644 --- a/torch_uncertainty/baselines/__init__.py +++ b/torch_uncertainty/baselines/__init__.py @@ -1,6 +1,3 @@ # flake8: noqa -# from .batched import BatchedResNet, BatchedWideResNet -# from .masked import MaskedResNet, MaskedWideResNet -# from .packed import PackedResNet, PackedWideResNet -# from .standard import ResNet, WideResNet from .classification.resnet import ResNet +from .classification.wideresnet import WideResNet diff --git a/torch_uncertainty/baselines/batched/__init__.py b/torch_uncertainty/baselines/batched/__init__.py deleted file mode 100644 index 9885dc90..00000000 --- a/torch_uncertainty/baselines/batched/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# flake8: noqa -from .resnet import BatchedResNet -from .wideresnet import BatchedWideResNet diff --git a/torch_uncertainty/baselines/batched/resnet.py b/torch_uncertainty/baselines/batched/resnet.py deleted file mode 100644 index 0c4f83ea..00000000 --- a/torch_uncertainty/baselines/batched/resnet.py +++ /dev/null @@ -1,161 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any, Dict, Literal - -import torch -import torch.nn as nn -from torch import optim - -from torch_uncertainty.models.resnet import ( - batched_resnet18, - batched_resnet34, - batched_resnet50, - batched_resnet101, - batched_resnet152, -) -from torch_uncertainty.routines.classification import ClassificationEnsemble - -# fmt: on -archs = [ - batched_resnet18, - batched_resnet34, - batched_resnet50, - batched_resnet101, - batched_resnet152, -] -choices = [18, 34, 50, 101, 152] - - -class BatchedResNet(ClassificationEnsemble): - r"""LightningModule for BatchEnsembles ResNet. - - Args: - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - in_channels (int): Number of input channels. - arch (int): - Determines which ResNet architecture to use: - - - ``18``: ResNet-18 - - ``32``: ResNet-32 - - ``50``: ResNet-50 - - ``101``: ResNet-101 - - ``152``: ResNet-152 - - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. 
- """ - - def __init__( - self, - num_classes: int, - num_estimators: int, - in_channels: int, - arch: Literal[18, 34, 50, 101, 152], - loss: nn.Module, - optimization_procedure: Any, - imagenet_structure: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__( - num_classes=num_classes, - num_estimators=num_estimators, - **kwargs, - ) - - # construct config - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = archs[choices.index(arch)]( - in_channels=in_channels, - num_estimators=num_estimators, - num_classes=num_classes, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - def configure_optimizers(self) -> dict: - param_optimizer = self.optimization_procedure(self)["optimizer"] - weight_decay = param_optimizer.defaults["weight_decay"] - lr = param_optimizer.defaults["lr"] - momentum = param_optimizer.defaults["momentum"] - my_list = ["R", "S"] - params_multi_tmp = list( - filter( - lambda kv: (my_list[0] in kv[0]) or (my_list[1] in kv[0]), - self.named_parameters(), - ) - ) - param_core_tmp = list( - filter( - lambda kv: (my_list[0] not in kv[0]) - and (my_list[1] not in kv[0]), - self.named_parameters(), - ) - ) - params_multi = [param for _, param in params_multi_tmp] - param_core = [param for _, param in param_core_tmp] - optimizer = optim.SGD( - [ - {"params": param_core, "weight_decay": weight_decay}, - {"params": params_multi, "weight_decay": 0.0}, - ], - lr=lr, - momentum=momentum, - ) - scheduler = self.optimization_procedure(self)["lr_scheduler"] - scheduler.optimizer = optimizer - return {"optimizer": optimizer, "lr_scheduler": scheduler} - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - input = input.repeat(self.num_estimators, 1, 1, 1) - return self.model.forward(input) - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--arch [int]``: defines :attr:`arch`. Defaults to ``18``. - - ``--num_estimators [int]``: defines :attr:`num_estimators`. Defaults - to ``1``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - Example: - - .. 
parsed-literal:: - - python script.py --arch 18 --no-imagenet_structure - """ - parent_parser = ClassificationEnsemble.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument( - "--arch", - type=int, - default=18, - choices=choices, - help="Type of ResNet", - ) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - parent_parser.add_argument("--num_estimators", type=int, default=4) - return parent_parser diff --git a/torch_uncertainty/baselines/batched/wideresnet.py b/torch_uncertainty/baselines/batched/wideresnet.py deleted file mode 100644 index 43e1e0ce..00000000 --- a/torch_uncertainty/baselines/batched/wideresnet.py +++ /dev/null @@ -1,128 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any, Dict - -import torch -import torch.nn as nn -from torch import optim - -from torch_uncertainty.models.wideresnet.batched import batched_wideresnet28x10 -from torch_uncertainty.routines.classification import ClassificationEnsemble - - -# fmt: on -class BatchedWideResNet(ClassificationEnsemble): - r"""LightningModule for BatchEnsembles WideResNet. - - Args: - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - in_channels (int): Number of input channels. - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - """ - - def __init__( - self, - num_classes: int, - num_estimators: int, - in_channels: int, - loss: nn.Module, - optimization_procedure: Any, - imagenet_structure: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__( - num_classes=num_classes, - num_estimators=num_estimators, - **kwargs, - ) - - # construct config - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = batched_wideresnet28x10( - in_channels=in_channels, - num_estimators=num_estimators, - num_classes=num_classes, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - def configure_optimizers(self) -> dict: - param_optimizer = self.optimization_procedure(self)["optimizer"] - weight_decay = param_optimizer.defaults["weight_decay"] - lr = param_optimizer.defaults["lr"] - momentum = param_optimizer.defaults["momentum"] - my_list = ["R", "S"] - params_multi_tmp = list( - filter( - lambda kv: (my_list[0] in kv[0]) or (my_list[1] in kv[0]), - self.named_parameters(), - ) - ) - param_core_tmp = list( - filter( - lambda kv: (my_list[0] not in kv[0]) - and (my_list[1] not in kv[0]), - self.named_parameters(), - ) - ) - params_multi = [param for _, param in params_multi_tmp] - param_core = [param for _, param in param_core_tmp] - optimizer = optim.SGD( - [ - {"params": param_core, "weight_decay": weight_decay}, - {"params": params_multi, "weight_decay": 0.0}, - ], - lr=lr, - momentum=momentum, - ) - scheduler = self.optimization_procedure(self)["lr_scheduler"] - scheduler.optimizer = optimizer - return {"optimizer": optimizer, "lr_scheduler": scheduler} - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - input = input.repeat(self.num_estimators, 1, 1, 1) - return self.model.forward(input) - - 
@staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--num_estimators [int]``: defines :attr:`num_estimators`. Defaults - to ``1``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - Example: - - .. parsed-literal:: - - python script.py --num_estimators 4 - """ - parent_parser = ClassificationEnsemble.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument("--num_estimators", type=int, default=4) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - return parent_parser diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 4541d65a..1a7d2980 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -81,6 +81,7 @@ def __new__( use_logits: bool = False, use_mi: bool = False, use_variation_ratio: bool = False, + pretrained: bool = False, **kwargs, ) -> LightningModule: params = { @@ -90,22 +91,42 @@ def __new__( } # version specific parameters if version == "vanilla": - # TODO: check parameters + # TODO: check parameters within a function + if groups < 1: + raise ValueError( + f"Number of groups must be at least 1, not {groups}" + ) params.update({"groups": groups}) elif version == "packed": - # TODO: check parameters + # TODO: check parameters within a function + if alpha <= 0: + raise ValueError( + f"Attribute `alpha` should be > 0, not {alpha}" + ) + if gamma < 1: + raise ValueError( + f"Attribute `gamma` should be >= 1, not {gamma}" + ) params.update( { "num_estimators": num_estimators, "alpha": alpha, "gamma": gamma, + "pretrained": pretrained, } ) elif version == "batched": - # TODO: check parameters params.update({"num_estimators": num_estimators}) elif version == "masked": - # TODO: check parameters + # TODO: check parameters within a function + if scale < 1: + raise ValueError( + f"Attribute `scale` should be >= 1, not {scale}." + ) + if groups < 1: + raise ValueError( + f"Attribute `groups` should be >= 1, not {groups}." + ) params.update( { "num_estimators": num_estimators, @@ -118,7 +139,6 @@ def __new__( model = cls.versions[version][cls.archs.index(arch)](**params) kwargs.update(params) - print(kwargs) # routine specific parameters if version in cls.single: return ClassificationSingle( @@ -147,22 +167,16 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: "--version", type=str, choices=cls.versions.keys(), - required=True, + default="vanilla", help=f"Variation of ResNet. Choose among: {cls.versions.keys()}", ) parser.add_argument( "--arch", type=int, choices=cls.archs, - required=True, + default=18, help=f"Architecture of ResNet. 
Choose among: {cls.archs}", ) - # parser.add_argument( - # "--imagenet_structure", - # action=BooleanOptionalAction, - # default=True, - # help="Use imagenet structure", - # ) parser.add_argument( "--num_estimators", type=int, @@ -221,5 +235,11 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: action=BooleanOptionalAction, default=False, ) + parser.add_argument( + "--pretrained", + dest="pretrained", + action=BooleanOptionalAction, + default=False, + ) return parser diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py new file mode 100644 index 00000000..d22a5065 --- /dev/null +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -0,0 +1,204 @@ +# fmt: off +from argparse import ArgumentParser, BooleanOptionalAction +from typing import Any, Literal, Optional + +import torch.nn as nn +from pytorch_lightning import LightningModule + +from torch_uncertainty.models.wideresnet import ( + batched_wideresnet28x10, + masked_wideresnet28x10, + packed_wideresnet28x10, + wideresnet28x10, +) +from torch_uncertainty.routines.classification import ( + ClassificationEnsemble, + ClassificationSingle, +) + + +# fmt: on +class WideResNet: + single = ["vanilla"] + ensemble = ["packed", "batched", "masked"] + versions = { + "vanilla": [wideresnet28x10], + "packed": [packed_wideresnet28x10], + "batched": [batched_wideresnet28x10], + "masked": [masked_wideresnet28x10], + } + + def __new__( + cls, + num_classes: int, + in_channels: int, + loss: nn.Module, + optimization_procedure: Any, + version: Literal["vanilla", "packed", "batched", "masked"], + imagenet_structure: bool = True, + num_estimators: Optional[int] = None, + groups: Optional[int] = None, + scale: Optional[float] = None, + alpha: Optional[int] = None, + gamma: Optional[int] = None, + use_entropy: bool = False, + use_logits: bool = False, + use_mi: bool = False, + use_variation_ratio: bool = False, + pretrained: bool = False, + **kwargs, + ) -> LightningModule: + # FIXME: should be a function to avoid repetition + params = { + "in_channels": in_channels, + "num_classes": num_classes, + "imagenet_structure": imagenet_structure, + } + # version specific params + if version == "vanilla": + # TODO: check parameters within a function + if groups < 1: + raise ValueError( + f"Number of groups must be at least 1, not {groups}" + ) + params.update({"groups": groups}) + elif version == "packed": + # TODO: check parameters within a function + if alpha <= 0: + raise ValueError( + f"Attribute `alpha` should be > 0, not {alpha}" + ) + if gamma < 1: + raise ValueError( + f"Attribute `gamma` should be >= 1, not {gamma}" + ) + params.update( + { + "num_estimators": num_estimators, + "alpha": alpha, + "gamma": gamma, + # "pretrained": pretrained, + } + ) + elif version == "batched": + params.update({"num_estimators": num_estimators}) + elif version == "masked": + # TODO: check parameters within a function + if scale < 1: + raise ValueError( + f"Attribute `scale` should be >= 1, not {scale}." + ) + if groups < 1: + raise ValueError( + f"Attribute `groups` should be >= 1, not {groups}." 
+ ) + params.update( + { + "num_estimators": num_estimators, + "scale": scale, + "groups": groups, + } + ) + else: + raise ValueError(f"Unknown version: {version}") + + model = cls.versions[version][0](**params) + kwargs.update(params) + # routine specific parameters + if version in cls.single: + return ClassificationSingle( + model=model, + loss=loss, + optimization_procedure=optimization_procedure, + use_entropy=use_entropy, + use_logits=use_logits, + **kwargs, + ) + elif version in cls.ensemble: + return ClassificationEnsemble( + model=model, + loss=loss, + optimization_procedure=optimization_procedure, + use_entropy=use_entropy, + use_logits=use_logits, + use_mi=use_mi, + use_variation_ratio=use_variation_ratio, + **kwargs, + ) + + @classmethod + def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--version", + type=str, + choices=cls.versions.keys(), + default="vanilla", + help="Variation of WideResNet. " + + f"Choose among: {cls.versions.keys()}", + ) + parser.add_argument( + "--num_estimators", + type=int, + default=None, + help="Number of estimators for ensemble", + ) + parser.add_argument( + "--groups", + type=int, + default=1, + help="Number of groups for vanilla or masked wideresnet", + ) + parser.add_argument( + "--scale", + type=float, + default=None, + help="Scale for masked wideresnet", + ) + parser.add_argument( + "--alpha", + type=int, + default=None, + help="Alpha for packed wideresnet", + ) + parser.add_argument( + "--gamma", + type=int, + default=None, + help="Gamma for packed wideresnet", + ) + # FIXME: should be a str to choose among the available OOD criteria + # rather than a boolean, but it is not possible since + # ClassificationSingle and ClassificationEnsemble have different OOD + # criteria. 
+ parser.add_argument( + "--entropy", + dest="use_entropy", + action=BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--logits", + dest="use_logits", + action=BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--mutual_information", + dest="use_mi", + action=BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--variation_ratio", + dest="use_variation_ratio", + action=BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--pretrained", + dest="pretrained", + action=BooleanOptionalAction, + default=False, + ) + + return parser diff --git a/torch_uncertainty/baselines/masked/__init__.py b/torch_uncertainty/baselines/masked/__init__.py deleted file mode 100644 index 1035ae87..00000000 --- a/torch_uncertainty/baselines/masked/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# flake8: noqa -from .resnet import MaskedResNet -from .wideresnet import MaskedWideResNet diff --git a/torch_uncertainty/baselines/masked/resnet.py b/torch_uncertainty/baselines/masked/resnet.py deleted file mode 100644 index 567efe1c..00000000 --- a/torch_uncertainty/baselines/masked/resnet.py +++ /dev/null @@ -1,149 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any, Dict, Literal - -import torch -import torch.nn as nn - -from torch_uncertainty.models.resnet import ( - masked_resnet18, - masked_resnet34, - masked_resnet50, - masked_resnet101, - masked_resnet152, -) -from torch_uncertainty.routines.classification import ClassificationEnsemble - -# fmt: on -archs = [ - masked_resnet18, - masked_resnet34, - masked_resnet50, - masked_resnet101, - masked_resnet152, -] -choices = [18, 34, 50, 101, 152] - - -class MaskedResNet(ClassificationEnsemble): - r"""LightningModule for Masksembles ResNet. - - Args: - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - in_channels (int): Number of input channels. - scale (int): Expansion factor affecting the width of the estimators. - groups (int): Number of groups within each estimator. - arch (int): - Determines which ResNet architecture to use: - - - ``18``: ResNet-18 - - ``32``: ResNet-32 - - ``50``: ResNet-50 - - ``101``: ResNet-101 - - ``152``: ResNet-152 - - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - - Raises: - ValueError: If :attr:`scale`:math:`<1`. - ValueError: If :attr:`groups`:math:`<1`. - """ - - def __init__( - self, - num_classes: int, - num_estimators: int, - in_channels: int, - scale: int, - groups: int, - arch: Literal[18, 34, 50, 101, 152], - loss: nn.Module, - optimization_procedure: Any, - imagenet_structure: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__( - num_classes=num_classes, - num_estimators=num_estimators, - **kwargs, - ) - - if scale < 1: - raise ValueError(f"Attribute `scale` should be >= 1, not {scale}.") - if groups < 1: - raise ValueError( - f"Attribute `groups` should be >= 1, not {groups}." 
- ) - - # construct config - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = archs[choices.index(arch)]( - in_channels=in_channels, - num_estimators=num_estimators, - scale=scale, - groups=groups, - num_classes=num_classes, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - def configure_optimizers(self) -> dict: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - input = input.repeat(self.num_estimators, 1, 1, 1) - return self.model.forward(input) - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--arch [int]``: defines :attr:`arch`. Defaults to ``18``. - - ``--num_estimators [int]``: defines :attr:`num_estimators`. Defaults - to ``1``. - - ``--scale [int]``: defines :attr:`scale`. Defaults to ``1``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - ``--groups [int]``: defines :attr:`groups`. Defaults to ``1``. - - Example: - - .. parsed-literal:: - - python script.py --arch 18 --num_estimators 4 --scale 2.0 - """ - parent_parser = ClassificationEnsemble.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument( - "--arch", - type=int, - default=18, - choices=choices, - help="Type of ResNet", - ) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - parent_parser.add_argument("--scale", type=float, default=2.0) - return parent_parser diff --git a/torch_uncertainty/baselines/masked/wideresnet.py b/torch_uncertainty/baselines/masked/wideresnet.py deleted file mode 100644 index 0eeb3365..00000000 --- a/torch_uncertainty/baselines/masked/wideresnet.py +++ /dev/null @@ -1,116 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any, Dict - -import torch -import torch.nn as nn - -from torch_uncertainty.models.wideresnet.masked import masked_wideresnet28x10 -from torch_uncertainty.routines.classification import ClassificationEnsemble - -# fmt: on - - -class MaskedWideResNet(ClassificationEnsemble): - r"""LightningModule for Masksembles WideResNet. - - Args: - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - in_channels (int): Number of input channels. - scale (int): Expansion factor affecting the width of the estimators. - groups (int): Number of groups within each estimator. - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. 
- """ - - def __init__( - self, - num_classes: int, - num_estimators: int, - in_channels: int, - scale: int, - groups: int, - loss: nn.Module, - optimization_procedure: Any, - use_entropy: bool = False, - use_logits: bool = False, - use_mi: bool = False, - use_variation_ratio: bool = False, - imagenet_structure: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__( - num_classes=num_classes, - num_estimators=num_estimators, - use_entropy=use_entropy, - use_logits=use_logits, - use_mi=use_mi, - use_variation_ratio=use_variation_ratio, - **kwargs, - ) - - # construct config - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = masked_wideresnet28x10( - in_channels=in_channels, - num_estimators=num_estimators, - scale=scale, - groups=groups, - num_classes=num_classes, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - def configure_optimizers(self) -> dict: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - input = input.repeat(self.num_estimators, 1, 1, 1) - return self.model.forward(input) - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--num_estimators [int]``: defines :attr:`num_estimators`. Defaults - to ``1``. - - ``--scale [float]``: defines :attr:`scale`. Defaults to ``2.0``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - ``--groups [int]``: defines :attr:`groups`. Defaults to ``1``. - - Example: - - .. 
parsed-literal:: - - python script.py --num_estimators 4 --scale 2.0 --groups 1 - """ - parent_parser = ClassificationEnsemble.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument("--num_estimators", type=int, default=4) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - parent_parser.add_argument("--scale", type=float, default=2.0) - parent_parser.add_argument("--groups", type=int, default=1) - return parent_parser diff --git a/torch_uncertainty/baselines/mimo/__init__.py b/torch_uncertainty/baselines/mimo/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/torch_uncertainty/baselines/packed/__init__.py b/torch_uncertainty/baselines/packed/__init__.py deleted file mode 100644 index 5a4fbef6..00000000 --- a/torch_uncertainty/baselines/packed/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# flake8: noqa -from .resnet import PackedResNet -from .wideresnet import PackedWideResNet diff --git a/torch_uncertainty/baselines/packed/resnet.py b/torch_uncertainty/baselines/packed/resnet.py deleted file mode 100644 index efdbe7ee..00000000 --- a/torch_uncertainty/baselines/packed/resnet.py +++ /dev/null @@ -1,195 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any, Dict, Literal - -import torch -import torch.nn as nn - -from torch_uncertainty.models.resnet import ( - packed_resnet18, - packed_resnet34, - packed_resnet50, - packed_resnet101, - packed_resnet152, -) -from torch_uncertainty.routines.classification import ClassificationEnsemble -from torch_uncertainty.utils import load_hf - -# fmt: on -archs = [ - packed_resnet18, - packed_resnet34, - packed_resnet50, - packed_resnet101, - packed_resnet152, -] -choices = [18, 34, 50, 101, 152] - -weight_ids = { - "10": { - "18": None, - "32": None, - "50": "pe_resnet50_c10", - "101": None, - "152": None, - }, - "100": { - "18": None, - "32": None, - "50": "pe_resnet50_c100", - "101": None, - "152": None, - }, - "1000": { - "18": None, - "32": None, - "50": "pe_resnet50_in1k", - "101": None, - "152": None, - }, - "1000_wider": { - "18": None, - "32": None, - "50": "pex4_resnet50", - "101": None, - "152": None, - }, -} - - -class PackedResNet(ClassificationEnsemble): - r"""LightningModule for Packed-Ensembles ResNet. - - Args: - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - in_channels (int): Number of input channels. - alpha (int): Expansion factor affecting the width of the estimators. - gamma (int): Number of groups within each estimator. - arch (int): - Determines which ResNet architecture to use: - - - ``18``: ResNet-18 - - ``32``: ResNet-32 - - ``50``: ResNet-50 - - ``101``: ResNet-101 - - ``152``: ResNet-152 - - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - pretrained (bool, optional): Indicates whether to use the pretrained - weights or not. Defaults to ``False``. - ``True``. Otherwise a :class:`ValueError()` will be raised. - - Raises: - ValueError: If :attr:`alpha`:math:`\leq 0`. - ValueError: If :attr:`gamma`:math:`<1`. 
- """ - - weights_id = "torch-uncertainty/pe_resnet50_in1k" - - def __init__( - self, - num_classes: int, - num_estimators: int, - in_channels: int, - alpha: int, - gamma: int, - arch: Literal[18, 34, 50, 101, 152], - loss: nn.Module, - optimization_procedure: Any, - imagenet_structure: bool = True, - pretrained: bool = False, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__( - num_classes=num_classes, - num_estimators=num_estimators, - **kwargs, - ) - - if alpha <= 0: - raise ValueError(f"Attribute `alpha` should be > 0, not {alpha}") - if gamma < 1: - raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}") - - # construct config - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = archs[choices.index(arch)]( - in_channels=in_channels, - num_estimators=num_estimators, - alpha=alpha, - gamma=gamma, - num_classes=num_classes, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - self._load(pretrained, arch, num_classes) - - def configure_optimizers(self) -> dict: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - return self.model.forward(input) - - def _load(self, pretrained: bool, arch: str, num_classes: int): - if pretrained: - weights = weight_ids[str(num_classes)][arch] - if weights is None: - raise ValueError("No pretrained weights for this configuration") - self.model.load_state_dict(load_hf(weights)) - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--arch [int]``: defines :attr:`arch`. Defaults to ``18``. - - ``--num_estimators [int]``: defines :attr:`num_estimators`. Defaults - to ``1``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - ``--alpha [int]``: defines :attr:`alpha`. Defaults to ``1``. - - ``--gamma [int]``: defines :attr:`gamma`. Defaults to ``1``. - - Example: - - .. parsed-literal:: - - python script.py --arch 18 --num_estimators 4 --alpha 2 - """ - parent_parser = ClassificationEnsemble.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument( - "--arch", - type=int, - choices=choices, - required=True, - help=f"Type of Packed-ResNet. 
Choose among {choices}", - ) - parent_parser.add_argument("--num_estimators", type=int, default=4) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - parent_parser.add_argument("--alpha", type=int, default=2) - parent_parser.add_argument("--gamma", type=int, default=1) - return parent_parser diff --git a/torch_uncertainty/baselines/packed/wideresnet.py b/torch_uncertainty/baselines/packed/wideresnet.py deleted file mode 100644 index ffbc97cc..00000000 --- a/torch_uncertainty/baselines/packed/wideresnet.py +++ /dev/null @@ -1,110 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any, Dict - -import torch -import torch.nn as nn - -from torch_uncertainty.models.wideresnet.packed import packed_wideresnet28x10 -from torch_uncertainty.routines.classification import ClassificationEnsemble - -# fmt: on - - -class PackedWideResNet(ClassificationEnsemble): - r"""LightningModule for Packed-Ensembles WideResNet. - - Args: - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - in_channels (int): Number of input channels. - alpha (int): Expansion factor affecting the width of the estimators. - gamma (int): Number of groups within each estimator. - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - """ - - def __init__( - self, - num_classes: int, - num_estimators: int, - in_channels: int, - alpha: int, - gamma: int, - loss: nn.Module, - optimization_procedure: Any, - imagenet_structure: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__( - num_classes=num_classes, - num_estimators=num_estimators, - **kwargs, - ) - - # construct config - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = packed_wideresnet28x10( - in_channels=in_channels, - num_estimators=num_estimators, - num_classes=num_classes, - alpha=alpha, - gamma=gamma, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - def configure_optimizers(self) -> dict: - """Configures the optimizers. - - Returns: - dict: Optimizers. - """ - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - return self.model.forward(input) - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--num_estimators [int]``: defines :attr:`num_estimators`. Defaults - to ``1``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - Example: - - .. 
parsed-literal:: - - python script.py --num_estimators 4 --alpha 2 - """ - parent_parser = ClassificationEnsemble.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument("--num_estimators", type=int, default=4) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - parent_parser.add_argument("--alpha", type=int, default=2) - parent_parser.add_argument("--gamma", type=int, default=1) - return parent_parser diff --git a/torch_uncertainty/baselines/standard/__init__.py b/torch_uncertainty/baselines/standard/__init__.py deleted file mode 100644 index a65fe270..00000000 --- a/torch_uncertainty/baselines/standard/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# flake8: noqa -from .resnet import ResNet -from .wideresnet import WideResNet diff --git a/torch_uncertainty/baselines/standard/resnet.py b/torch_uncertainty/baselines/standard/resnet.py deleted file mode 100644 index 69177a10..00000000 --- a/torch_uncertainty/baselines/standard/resnet.py +++ /dev/null @@ -1,130 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any - -import torch -import torch.nn as nn - -from torch_uncertainty.models.resnet import ( - resnet18, - resnet34, - resnet50, - resnet101, - resnet152, -) -from torch_uncertainty.routines.classification import ClassificationSingle - -# fmt: on -archs = [resnet18, resnet34, resnet50, resnet101, resnet152] -choices = [18, 34, 50, 101, 152] - - -class ResNet(ClassificationSingle): - r"""LightningModule for Vanilla ResNet. - - Args: - num_classes (int): Number of classes to predict. - in_channels (int): Number of input channels. - arch (int): - Determines which ResNet architecture to use: - - - ``18``: ResNet-18 - - ``32``: ResNet-32 - - ``50``: ResNet-50 - - ``101``: ResNet-101 - - ``152``: ResNet-152 - - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - groups (int, optional): Number of groups in convolutions. Defaults to - ``1``. - imagenet_structure (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - - Raises: - ValueError: If :attr:`groups` :math:`<1`. 
- """ - - def __init__( - self, - num_classes: int, - in_channels: int, - arch: int, - loss: nn.Module, - optimization_procedure: Any, - groups: int = 1, - imagenet_structure: bool = True, - **kwargs, - ) -> None: - super().__init__( - num_classes=num_classes, - **kwargs, - ) - - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - if groups < 1: - raise ValueError( - f"Number of groups must be at least 1, not {groups}" - ) - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = archs[choices.index(arch)]( - in_channels=in_channels, - num_classes=num_classes, - groups=groups, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - def configure_optimizers(self) -> dict: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - return self.model.forward(input) - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--arch [int]``: defines :attr:`arch`. Defaults to ``18``. - - ``--groups [int]``: defines :attr:`groups`. Defaults to ``1``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - Example: - - .. parsed-literal:: - - python script.py --groups 2 --no-imagenet_structure - """ - parent_parser = ClassificationSingle.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument( - "--arch", - type=int, - default=18, - choices=choices, - help="Type of ResNet", - ) - parent_parser.add_argument("--groups", type=int, default=1) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - return parent_parser diff --git a/torch_uncertainty/baselines/standard/wideresnet.py b/torch_uncertainty/baselines/standard/wideresnet.py deleted file mode 100644 index 25618cb4..00000000 --- a/torch_uncertainty/baselines/standard/wideresnet.py +++ /dev/null @@ -1,97 +0,0 @@ -# fmt: off -from argparse import ArgumentParser, BooleanOptionalAction -from typing import Any, Dict - -import torch -import torch.nn as nn - -from torch_uncertainty.models.wideresnet.std import wideresnet28x10 -from torch_uncertainty.routines.classification import ClassificationSingle - -# fmt: on - - -class WideResNet(ClassificationSingle): - r"""LightningModule for Vanilla WideResNet. - - Args: - num_classes (int): Number of classes to predict. - in_channels (int): Number of input channels. - loss (torch.nn.Module): Training loss. - optimization_procedure (Any): Optimization procedure, corresponds to - what expect the `LightningModule.configure_optimizers() - `_ - method. - groups (int, optional): Number of groups in convolutions. Defaults to - ``1``. 
- """ - - def __init__( - self, - num_classes: int, - in_channels: int, - loss: nn.Module, - optimization_procedure: Any, - groups: int = 1, - imagenet_structure: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__( - num_classes=num_classes, - **kwargs, - ) - - # construct config - self.save_hyperparameters(ignore=["loss", "optimization_procedure"]) - assert groups >= 1 - - self.loss = loss - self.optimization_procedure = optimization_procedure - - self.model = wideresnet28x10( - in_channels=in_channels, - num_classes=num_classes, - groups=groups, - imagenet_structure=imagenet_structure, - ) - - # to log the graph - self.example_input_array = torch.randn(1, in_channels, 32, 32) - - def configure_optimizers(self) -> dict: - return self.optimization_procedure(self) - - @property - def criterion(self) -> nn.Module: - return self.loss() - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - return self.model.forward(input) - - @staticmethod - def add_model_specific_args( - parent_parser: ArgumentParser, - ) -> ArgumentParser: - """Defines the model's attributes via command-line options: - - - ``--groups [int]``: defines :attr:`groups`. Defaults to ``1``. - - ``--imagenet_structure``: sets :attr:`imagenet_structure`. Defaults - to ``True``. - - Example: - - .. parsed-literal:: - - python script.py --groups 2 --no-imagenet_structure - """ - parent_parser = ClassificationSingle.add_model_specific_args( - parent_parser - ) - parent_parser.add_argument("--groups", type=int, default=1) - parent_parser.add_argument( - "--imagenet_structure", - action=BooleanOptionalAction, - default=True, - help="Use imagenet structure", - ) - return parent_parser diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index 07e1409b..04ef93c6 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -18,6 +18,7 @@ class CIFAR10DataModule(LightningDataModule): num_classes = 10 num_channels = 3 + input_shape = (3, 32, 32) def __init__( self, diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index c562193b..30072886 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -19,6 +19,7 @@ class CIFAR100DataModule(LightningDataModule): num_classes = 100 num_channels = 3 + input_shape = (3, 32, 32) def __init__( self, diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index 25c84436..4a0d41f6 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -148,6 +148,7 @@ def __init__( ): super().__init__() self.in_planes = 64 * width_multiplier + self.num_estimators = num_estimators self.width_multiplier = width_multiplier if imagenet_structure: @@ -225,7 +226,8 @@ def _make_layer(self, block, planes, num_blocks, stride, num_estimators): return nn.Sequential(*layers) def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) + out = x.repeat(self.num_estimators, 1, 1, 1) + out = F.relu(self.bn1(self.conv1(out))) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 5eadd799..88e971f1 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -279,7 +279,8 @@ def _make_layer( return nn.Sequential(*layers) def forward(self, x: 
Tensor) -> Tensor: - out = F.relu(self.bn1(self.conv1(x))) + out = x.repeat(self.num_estimators, 1, 1, 1) + out = F.relu(self.bn1(self.conv1(out))) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/resnet/packed.py index ab70e079..37671303 100644 --- a/torch_uncertainty/models/resnet/packed.py +++ b/torch_uncertainty/models/resnet/packed.py @@ -7,6 +7,7 @@ from torch import Tensor from ...layers import PackedConv2d, PackedLinear +from ...utils import load_hf # fmt: on __all__ = [ @@ -17,6 +18,37 @@ "packed_resnet152", ] +weight_ids = { + "10": { + "18": None, + "32": None, + "50": "pe_resnet50_c10", + "101": None, + "152": None, + }, + "100": { + "18": None, + "32": None, + "50": "pe_resnet50_c100", + "101": None, + "152": None, + }, + "1000": { + "18": None, + "32": None, + "50": "pe_resnet50_in1k", + "101": None, + "152": None, + }, + "1000_wider": { + "18": None, + "32": None, + "50": "pex4_resnet50", + "101": None, + "152": None, + }, +} + class BasicBlock(nn.Module): expansion = 1 @@ -311,6 +343,7 @@ def packed_resnet18( gamma: int, num_classes: int, imagenet_structure: bool = True, + pretrained: bool = False, ) -> _PackedResNet: """Packed-Ensembles of ResNet-18 from `Deep Residual Learning for Image Recognition `_. @@ -325,7 +358,7 @@ def packed_resnet18( Returns: _PackedResNet: A Packed-Ensembles ResNet-18. """ - return _PackedResNet( + net = _PackedResNet( block=BasicBlock, num_blocks=[2, 2, 2, 2], in_channels=in_channels, @@ -335,6 +368,12 @@ def packed_resnet18( num_classes=num_classes, imagenet_structure=imagenet_structure, ) + if pretrained: + weights = weight_ids[str(num_classes)][18] + if weights is None: + raise ValueError("No pretrained weights for this configuration") + net.load_state_dict(load_hf(weights)) + return net def packed_resnet34( @@ -344,6 +383,7 @@ def packed_resnet34( gamma: int, num_classes: int, imagenet_structure: bool = True, + pretrained: bool = False, ) -> _PackedResNet: """Packed-Ensembles of ResNet-34 from `Deep Residual Learning for Image Recognition `_. @@ -358,7 +398,7 @@ def packed_resnet34( Returns: _PackedResNet: A Packed-Ensembles ResNet-34. """ - return _PackedResNet( + net = _PackedResNet( block=BasicBlock, num_blocks=[3, 4, 6, 3], in_channels=in_channels, @@ -368,6 +408,12 @@ def packed_resnet34( num_classes=num_classes, imagenet_structure=imagenet_structure, ) + if pretrained: + weights = weight_ids[str(num_classes)][34] + if weights is None: + raise ValueError("No pretrained weights for this configuration") + net.load_state_dict(load_hf(weights)) + return net def packed_resnet50( @@ -377,6 +423,7 @@ def packed_resnet50( gamma: int, num_classes: int, imagenet_structure: bool = True, + pretrained: bool = False, ) -> _PackedResNet: """Packed-Ensembles of ResNet-50 from `Deep Residual Learning for Image Recognition `_. @@ -391,7 +438,7 @@ def packed_resnet50( Returns: _PackedResNet: A Packed-Ensembles ResNet-50. 
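To illustrate the new ``pretrained`` flag, a sketch of the intended call — assuming the integer indices in the lookups (e.g. ``weight_ids[str(num_classes)][50]``) are meant to hit the string keys of ``weight_ids``, and with placeholder ensemble hyperparameters; only configurations with a non-``None`` entry, such as the ResNet-50 checkpoints, can resolve:

```python
from torch_uncertainty.models.resnet import packed_resnet50

# Intended to fetch the "pe_resnet50_c10" checkpoint via `load_hf`;
# configurations mapped to None raise
# ValueError("No pretrained weights for this configuration").
net = packed_resnet50(
    in_channels=3,
    num_estimators=4,  # placeholder
    alpha=2,           # placeholder
    gamma=2,           # placeholder
    num_classes=10,    # selects the "10" (CIFAR-10) row of weight_ids
    pretrained=True,
)
```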
""" - return _PackedResNet( + net = _PackedResNet( block=Bottleneck, num_blocks=[3, 4, 6, 3], in_channels=in_channels, @@ -401,6 +448,12 @@ def packed_resnet50( num_classes=num_classes, imagenet_structure=imagenet_structure, ) + if pretrained: + weights = weight_ids[str(num_classes)][50] + if weights is None: + raise ValueError("No pretrained weights for this configuration") + net.load_state_dict(load_hf(weights)) + return net def packed_resnet101( @@ -410,6 +463,7 @@ def packed_resnet101( gamma: int, num_classes: int, imagenet_structure: bool = True, + pretrained: bool = False, ) -> _PackedResNet: """Packed-Ensembles of ResNet-101 from `Deep Residual Learning for Image Recognition `_. @@ -424,7 +478,7 @@ def packed_resnet101( Returns: _PackedResNet: A Packed-Ensembles ResNet-101. """ - return _PackedResNet( + net = _PackedResNet( block=Bottleneck, num_blocks=[3, 4, 23, 3], in_channels=in_channels, @@ -434,6 +488,12 @@ def packed_resnet101( num_classes=num_classes, imagenet_structure=imagenet_structure, ) + if pretrained: + weights = weight_ids[str(num_classes)][101] + if weights is None: + raise ValueError("No pretrained weights for this configuration") + net.load_state_dict(load_hf(weights)) + return net def packed_resnet152( @@ -443,6 +503,7 @@ def packed_resnet152( gamma: int, num_classes: int, imagenet_structure: bool = True, + pretrained: bool = False, ) -> _PackedResNet: """Packed-Ensembles of ResNet-152 from `Deep Residual Learning for Image Recognition `_. @@ -459,7 +520,7 @@ def packed_resnet152( Returns: _PackedResNet: A Packed-Ensembles ResNet-152. """ - return _PackedResNet( + net = _PackedResNet( block=Bottleneck, num_blocks=[3, 8, 36, 3], in_channels=in_channels, @@ -469,3 +530,9 @@ def packed_resnet152( num_classes=num_classes, imagenet_structure=imagenet_structure, ) + if pretrained: + weights = weight_ids[str(num_classes)][152] + if weights is None: + raise ValueError("No pretrained weights for this configuration") + net.load_state_dict(load_hf(weights)) + return net diff --git a/torch_uncertainty/models/wideresnet/batched.py b/torch_uncertainty/models/wideresnet/batched.py index b7dbea1d..42686bf8 100644 --- a/torch_uncertainty/models/wideresnet/batched.py +++ b/torch_uncertainty/models/wideresnet/batched.py @@ -173,7 +173,8 @@ def _wide_layer( return nn.Sequential(*layers) def forward(self, x): - out = self.conv1(x) + out = x.repeat(self.num_estimators, 1, 1, 1) + out = self.conv1(out) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) diff --git a/torch_uncertainty/models/wideresnet/masked.py b/torch_uncertainty/models/wideresnet/masked.py index 36a72170..ee2369ac 100644 --- a/torch_uncertainty/models/wideresnet/masked.py +++ b/torch_uncertainty/models/wideresnet/masked.py @@ -189,7 +189,8 @@ def _wide_layer( return nn.Sequential(*layers) def forward(self, x): - out = self.conv1(x) + out = x.repeat(self.num_estimators, 1, 1, 1) + out = self.conv1(out) out = self.optional_pool(out) out = self.layer1(out) out = self.layer2(out) From 406adfca10e7aafadb2cf03cf36c39145f29c22f Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 29 May 2023 18:12:31 +0200 Subject: [PATCH 05/23] Update experiments :hammer: --- experiments/batched/resnet18.py | 23 ---- experiments/classification/cifar10/resnet.py | 33 ++++++ experiments/classification/cifar100/resnet.py | 33 ++++++ experiments/classification/readme.md | 15 +++ experiments/classification/resnet_cifar10.py | 108 ----------------- experiments/experiments.py | 67 ----------- experiments/masked/resnet18.py | 
23 ---- experiments/packed/resnet18_cifar10.py | 22 ---- experiments/packed/resnet18_cifar100.py | 22 ---- experiments/packed/resnet50_cifar10.py | 22 ---- experiments/packed/resnet50_cifar100.py | 22 ---- experiments/readme.md | 7 ++ experiments/standard/resnet18_cifar10.py | 22 ---- experiments/standard/resnet18_cifar100.py | 22 ---- experiments/standard/resnet50_cifar10.py | 22 ---- experiments/standard/resnet50_cifar100.py | 22 ---- torch_uncertainty/__init__.py | 109 ++++++------------ 17 files changed, 121 insertions(+), 473 deletions(-) delete mode 100644 experiments/batched/resnet18.py create mode 100644 experiments/classification/cifar10/resnet.py create mode 100644 experiments/classification/cifar100/resnet.py create mode 100644 experiments/classification/readme.md delete mode 100644 experiments/classification/resnet_cifar10.py delete mode 100644 experiments/experiments.py delete mode 100644 experiments/masked/resnet18.py delete mode 100644 experiments/packed/resnet18_cifar10.py delete mode 100644 experiments/packed/resnet18_cifar100.py delete mode 100644 experiments/packed/resnet50_cifar10.py delete mode 100644 experiments/packed/resnet50_cifar100.py create mode 100644 experiments/readme.md delete mode 100644 experiments/standard/resnet18_cifar10.py delete mode 100644 experiments/standard/resnet18_cifar100.py delete mode 100644 experiments/standard/resnet50_cifar10.py delete mode 100644 experiments/standard/resnet50_cifar100.py diff --git a/experiments/batched/resnet18.py b/experiments/batched/resnet18.py deleted file mode 100644 index c8338296..00000000 --- a/experiments/batched/resnet18.py +++ /dev/null @@ -1,23 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.batched import BatchedResNet -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - # print(root) - cli_main( - BatchedResNet, - CIFAR10DataModule, - nn.CrossEntropyLoss, - optim_cifar10_resnet18, - root, - "batched", - ) diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py new file mode 100644 index 00000000..bd9df801 --- /dev/null +++ b/experiments/classification/cifar10/resnet.py @@ -0,0 +1,33 @@ +# fmt: off +from pathlib import Path + +import torch.nn as nn + +from torch_uncertainty import cls_main, init_args +from torch_uncertainty.baselines import ResNet +from torch_uncertainty.datamodules import CIFAR10DataModule +from torch_uncertainty.optimization_procedures import get_procedure + +# fmt: on +if __name__ == "__main__": + root = Path(__file__).parent.absolute().parents[1] + + args = init_args(ResNet, CIFAR10DataModule) + + net_name = f"{args.version}-resnet{args.arch}-cifar10" + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # model + model = ResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=get_procedure(f"resnet{args.arch}", "cifar10"), + imagenet_structure=False, + **vars(args), + ) + + cls_main(model, dm, root, net_name, args) diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py new file mode 100644 index 00000000..de861a93 --- /dev/null +++ b/experiments/classification/cifar100/resnet.py @@ -0,0 +1,33 @@ +# fmt: off 
+from pathlib import Path + +import torch.nn as nn + +from torch_uncertainty import cls_main, init_args +from torch_uncertainty.baselines import ResNet +from torch_uncertainty.datamodules import CIFAR100DataModule +from torch_uncertainty.optimization_procedures import get_procedure + +# fmt: on +if __name__ == "__main__": + root = Path(__file__).parent.absolute().parents[1] + + args = init_args(ResNet, CIFAR100DataModule) + + net_name = f"{args.version}-resnet{args.arch}-cifar10" + + # datamodule + args.root = str(root / "data") + dm = CIFAR100DataModule(**vars(args)) + + # model + model = ResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=get_procedure(f"resnet{args.arch}", "cifar100"), + imagenet_structure=False, + **vars(args), + ) + + cls_main(model, dm, root, net_name, args) diff --git a/experiments/classification/readme.md b/experiments/classification/readme.md new file mode 100644 index 00000000..8bbc67aa --- /dev/null +++ b/experiments/classification/readme.md @@ -0,0 +1,15 @@ +# Classification Benchmarks + +*Work in progress* + +## Image Classification + +### CIFAR-10 + +* ResNet +* WideResNet + +### CIFAR-100 + +* ResNet +* WideResNet \ No newline at end of file diff --git a/experiments/classification/resnet_cifar10.py b/experiments/classification/resnet_cifar10.py deleted file mode 100644 index 67224818..00000000 --- a/experiments/classification/resnet_cifar10.py +++ /dev/null @@ -1,108 +0,0 @@ -# fmt: off -from argparse import ArgumentParser -from pathlib import Path - -import pytorch_lightning as pl -import torch -import torch.nn as nn -from pytorch_lightning.callbacks import LearningRateMonitor -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from torchinfo import summary - -import numpy as np -from torch_uncertainty.baselines import ResNet -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import get_procedure -from torch_uncertainty.utils import get_version - -# fmt: on -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - - parser = ArgumentParser("torch-uncertainty") - parser.add_argument("--seed", type=int, default=None) - parser.add_argument("--test", type=int, default=None) - parser.add_argument("--summary", dest="summary", action="store_true") - parser.add_argument("--log_graph", dest="log_graph", action="store_true") - parser.add_argument( - "--channels_last", - action="store_true", - help="Use channels last memory format", - ) - - parser = pl.Trainer.add_argparse_args(parser) - parser = CIFAR10DataModule.add_argparse_args(parser) - parser = ResNet.add_model_specific_args(parser) - args = parser.parse_args() - - # print(args) - - if isinstance(root, str): - root = Path(root) - - if isinstance(args.seed, int): - pl.seed_everything(args.seed) - - net_name = f"{args.version}-resnet{args.arch}-cifar10" - - # datamodule - args.root = str(root / "data") - dm = CIFAR10DataModule(**vars(args)) - - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure(f"resnet{args.arch}", "cifar10"), - imagenet_structure=False, - **vars(args), - ) - - if args.channels_last: - model = model.to(memory_format=torch.channels_last) - - # logger - tb_logger = TensorBoardLogger( - 
str(root / "logs"), - name=net_name, - default_hp_metric=False, - log_graph=args.log_graph, - version=args.test, - ) - - # callbacks - save_checkpoints = ModelCheckpoint( - monitor="hp/val_acc", - mode="max", - save_last=True, - save_weights_only=True, - ) - - # Select the best model, monitor the lr and stop if NaN - callbacks = [ - save_checkpoints, - LearningRateMonitor(logging_interval="step"), - EarlyStopping(monitor="hp/val_nll", patience=np.inf, check_finite=True), - ] - # trainer - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=callbacks, - logger=tb_logger, - deterministic=(args.seed is not None), - ) - - if args.summary: - summary(model, input_size=model.example_input_array.shape) - elif args.test is not None: - ckpt_file, _ = get_version( - root=(root / "logs" / net_name), version=args.test - ) - trainer.test(model, datamodule=dm, ckpt_path=str(ckpt_file)) - else: - # training and testing - trainer.fit(model, dm) - trainer.test(datamodule=dm, ckpt_path="best") diff --git a/experiments/experiments.py b/experiments/experiments.py deleted file mode 100644 index 14ba34a7..00000000 --- a/experiments/experiments.py +++ /dev/null @@ -1,67 +0,0 @@ -# fmt: off -from argparse import ArgumentParser -from pathlib import Path - -import pytorch_lightning as pl -import torch.nn as nn - -from torch_uncertainty import main -from torch_uncertainty.baselines.batched import BatchedResNet -from torch_uncertainty.baselines.masked import MaskedResNet -from torch_uncertainty.baselines.packed import PackedResNet -from torch_uncertainty.baselines.standard import ResNet -from torch_uncertainty.datamodules import CIFAR10DataModule, CIFAR100DataModule -from torch_uncertainty.optimization_procedures import get_procedure - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - - parser = ArgumentParser("torch-uncertainty") - parser.add_argument("--seed", type=int, default=None) - parser.add_argument("--test", type=int, default=None) - parser.add_argument("--summary", dest="summary", action="store_true") - parser.add_argument("--log_graph", dest="log_graph", action="store_true") - parser.add_argument( - "--type", choices=["standard", "packed", "masked", "batched"] - ) - parser.add_argument( - "--model", choices=["resnet18", "resnet50", "wideresnet28x10"] - ) - parser.add_argument("--data", choices=["cifar10", "cifar100", "imagenet"]) - - parser = pl.Trainer.add_argparse_args(parser) - parser = ResNet.add_model_specific_args(parser) - parser = CIFAR10DataModule.add_argparse_args(parser) - parser = CIFAR100DataModule.add_argparse_args(parser) - args = parser.parse_args() - - if args.data == "cifar10": - datamodule = CIFAR10DataModule - elif args.data == "cifar100": - datamodule = CIFAR100DataModule - elif args.data == "imagenet": - raise NotImplementedError("ImageNet not yet implemented") - else: - raise ValueError(f"Unknown dataset: {args.data}") - - if args.type == "standard": - model_type = ResNet - elif args.type == "masked": - model_type = MaskedResNet - elif args.type == "batched": - model_type = BatchedResNet - elif args.type == "packed": - model_type = PackedResNet - else: - raise ValueError(f"Unknown model type: {args.type}") - - main( - model_type, - datamodule, - nn.CrossEntropyLoss, - get_procedure(args.model, args.data), - root, - f"{args.model}_{args.data}", - ) diff --git a/experiments/masked/resnet18.py b/experiments/masked/resnet18.py deleted file mode 100644 index c724e4b4..00000000 --- a/experiments/masked/resnet18.py +++ /dev/null @@ -1,23 
+0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.masked import MaskedResNet -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - # print(root) - cli_main( - MaskedResNet, - CIFAR10DataModule, - nn.CrossEntropyLoss, - optim_cifar10_resnet18, - root, - "masked", - ) diff --git a/experiments/packed/resnet18_cifar10.py b/experiments/packed/resnet18_cifar10.py deleted file mode 100644 index ee54c14d..00000000 --- a/experiments/packed/resnet18_cifar10.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.packed import PackedResNet -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - PackedResNet, - CIFAR10DataModule, - nn.CrossEntropyLoss, - optim_cifar10_resnet18, - root, - "packed-resnet18-cifar10", - ) diff --git a/experiments/packed/resnet18_cifar100.py b/experiments/packed/resnet18_cifar100.py deleted file mode 100644 index 779ae44d..00000000 --- a/experiments/packed/resnet18_cifar100.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.packed import PackedResNet -from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import optim_cifar100_resnet18 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - PackedResNet, - CIFAR100DataModule, - nn.CrossEntropyLoss, - optim_cifar100_resnet18, - root, - "packed-resnet18-cifar100", - ) diff --git a/experiments/packed/resnet50_cifar10.py b/experiments/packed/resnet50_cifar10.py deleted file mode 100644 index 2fb4b337..00000000 --- a/experiments/packed/resnet50_cifar10.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.packed import PackedResNet -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet50 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - PackedResNet, - CIFAR10DataModule, - nn.CrossEntropyLoss, - optim_cifar10_resnet50, - root, - "packed-resnet50-cifar10", - ) diff --git a/experiments/packed/resnet50_cifar100.py b/experiments/packed/resnet50_cifar100.py deleted file mode 100644 index 6e8c776d..00000000 --- a/experiments/packed/resnet50_cifar100.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.packed import PackedResNet -from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import optim_cifar100_resnet50 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - PackedResNet, - CIFAR100DataModule, - nn.CrossEntropyLoss, - 
optim_cifar100_resnet50, - root, - "packed-resnet50-cifar100", - ) diff --git a/experiments/readme.md b/experiments/readme.md new file mode 100644 index 00000000..05bf5fc5 --- /dev/null +++ b/experiments/readme.md @@ -0,0 +1,7 @@ +# Experiments + +Torch-Uncertainty proposes various benchmarks to evaluate the uncertainty estimation methods. + +## Classification + +*Work in progress* \ No newline at end of file diff --git a/experiments/standard/resnet18_cifar10.py b/experiments/standard/resnet18_cifar10.py deleted file mode 100644 index 63e3c7b4..00000000 --- a/experiments/standard/resnet18_cifar10.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.standard import ResNet -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - ResNet, - CIFAR10DataModule, - nn.CrossEntropyLoss, - optim_cifar10_resnet18, - root, - "std-resnet18-cifar10", - ) diff --git a/experiments/standard/resnet18_cifar100.py b/experiments/standard/resnet18_cifar100.py deleted file mode 100644 index a8e87b0f..00000000 --- a/experiments/standard/resnet18_cifar100.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.standard import ResNet -from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import optim_cifar100_resnet18 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - ResNet, - CIFAR100DataModule, - nn.CrossEntropyLoss, - optim_cifar100_resnet18, - root, - "std-resnet18-cifar100", - ) diff --git a/experiments/standard/resnet50_cifar10.py b/experiments/standard/resnet50_cifar10.py deleted file mode 100644 index fa7eb68a..00000000 --- a/experiments/standard/resnet50_cifar10.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.standard import ResNet -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.optimization_procedures import optim_cifar10_resnet50 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - ResNet, - CIFAR10DataModule, - nn.CrossEntropyLoss, - optim_cifar10_resnet50, - root, - "std-resnet50-cifar10", - ) diff --git a/experiments/standard/resnet50_cifar100.py b/experiments/standard/resnet50_cifar100.py deleted file mode 100644 index 21bf018c..00000000 --- a/experiments/standard/resnet50_cifar100.py +++ /dev/null @@ -1,22 +0,0 @@ -# fmt: off -from pathlib import Path - -import torch.nn as nn - -from torch_uncertainty import cli_main -from torch_uncertainty.baselines.standard import ResNet -from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.optimization_procedures import optim_cifar100_resnet50 - -# fmt: on - -if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] - cli_main( - ResNet, - CIFAR100DataModule, - nn.CrossEntropyLoss, - optim_cifar100_resnet50, - root, - "std-resnet50-cifar100", - ) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index b669ecd8..5ce486d4 100644 --- 
a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -2,7 +2,7 @@ # flake8: noqa from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import Callable, Type, Union +from typing import Type, Union import pytorch_lightning as pl import torch @@ -10,21 +10,39 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from torch.nn import Module from torchinfo import summary import numpy as np -from .routines.classification import ClassificationSingle from .utils import get_version # fmt: on -def main( - network: Type[ClassificationSingle], - datamodule: Type[pl.LightningDataModule], - loss: Module, - optimization_procedure: Callable[[Module], dict], +def init_args( + network: Type[pl.LightningModule], datamodule: Type[pl.LightningDataModule] +) -> Namespace: + parser = ArgumentParser("torch-uncertainty") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--test", type=int, default=None) + parser.add_argument("--summary", dest="summary", action="store_true") + parser.add_argument("--log_graph", dest="log_graph", action="store_true") + parser.add_argument( + "--channels_last", + action="store_true", + help="Use channels last memory format", + ) + + parser = pl.Trainer.add_argparse_args(parser) + parser = datamodule.add_argparse_args(parser) + parser = network.add_model_specific_args(parser) + args = parser.parse_args() + + return args + + +def cls_main( + network: pl.LightningModule, + datamodule: pl.LightningDataModule, root: Union[Path, str], net_name: str, args: Namespace, @@ -32,31 +50,11 @@ def main( if isinstance(root, str): root = Path(root) - if args.seed: - pl.seed_everything(args.seed, workers=True) - - if args.max_epochs is None: - print( - "Setting max_epochs to 1 for testing purposes. Set max_epochs " - "manually to train the model." 
- ) - args.max_epochs = 1 - - # datamodule - args.root = str(root / "data") - dm = datamodule(**vars(args)) - - # model - model = network( - loss=loss, - optimization_procedure=optimization_procedure, - num_classes=dm.num_classes, - in_channels=dm.num_channels, - **vars(args), - ) + if isinstance(args.seed, int): + pl.seed_everything(args.seed) if args.channels_last: - model = model.to(memory_format=torch.channels_last) + network = network.to(memory_format=torch.channels_last) # logger tb_logger = TensorBoardLogger( @@ -81,7 +79,6 @@ def main( LearningRateMonitor(logging_interval="step"), EarlyStopping(monitor="hp/val_nll", patience=np.inf, check_finite=True), ] - # trainer trainer = pl.Trainer.from_argparse_args( args, @@ -91,53 +88,13 @@ def main( ) if args.summary: - summary(model, input_size=model.example_input_array.shape) + summary(network, input_size=list(datamodule.input_shape).insert(0, 1)) elif args.test is not None: ckpt_file, _ = get_version( root=(root / "logs" / net_name), version=args.test ) - trainer.test(model, datamodule=dm, ckpt_path=str(ckpt_file)) + trainer.test(network, datamodule=datamodule, ckpt_path=str(ckpt_file)) else: # training and testing - trainer.fit(model, dm) - trainer.test(datamodule=dm, ckpt_path="best") - - -def cli_main( - network: Type[ClassificationSingle], - datamodule: Type[pl.LightningDataModule], - loss: Module, - optimization_procedure: Callable[[Module], dict], - root: Union[Path, str], - net_name: str, -) -> None: - parser = ArgumentParser("torch-uncertainty") - parser.add_argument( - "--seed", - type=int, - default=None, - help="Set the random seed to some value for reproducibility", - ) - parser.add_argument( - "--test", - type=int, - default=None, - help="Test a specific version of the model. The checkpoint must be available in the logs folder.", - ) - parser.add_argument( - "--summary", action="store_true", help="Print a summary of the model" - ) - parser.add_argument("--log_graph", action="store_true") - parser.add_argument( - "--channels_last", - action="store_true", - help="Use channels last memory format", - ) - - parser = pl.Trainer.add_argparse_args(parser) - parser = datamodule.add_argparse_args(parser) - parser = network.add_model_specific_args(parser) - args = parser.parse_args() - main( - network, datamodule, loss, optimization_procedure, root, net_name, args - ) + trainer.fit(network, datamodule) + trainer.test(datamodule=datamodule, ckpt_path="best") From 662366c74d20a8c11cb7f1585470371f8439bdb2 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 29 May 2023 18:13:40 +0200 Subject: [PATCH 06/23] Update tests accordingly to changes :hammer: --- tests/baselines/test_batched.py | 29 ++++++------- tests/baselines/test_masked.py | 61 ++++++++++++++------------- tests/baselines/test_packed.py | 56 ++++++++++++------------- tests/baselines/test_standard.py | 24 +++++------ tests/test_cli.py | 71 ++++++++++---------------------- 5 files changed, 101 insertions(+), 140 deletions(-) diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index 88bfe599..8d550fc5 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -1,11 +1,9 @@ # fmt:off -from argparse import ArgumentParser - import torch import torch.nn as nn from torchinfo import summary -from torch_uncertainty.baselines.batched import BatchedResNet, BatchedWideResNet +from torch_uncertainty.baselines import ResNet, WideResNet from torch_uncertainty.optimization_procedures import ( optim_cifar10_wideresnet, optim_cifar100_resnet50, 
@@ -17,18 +15,17 @@ class TestBatchedBaseline: """Testing the BatchedResNet baseline class.""" def test_batched(self): - net = BatchedResNet( - arch=18, - in_channels=3, + net = ResNet( num_classes=10, - num_estimators=4, + in_channels=3, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar100_resnet50, + version="batched", + arch=18, imagenet_structure=False, + num_estimators=4, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--no-imagenet_structure"]) + summary(net) _ = net.criterion @@ -40,19 +37,17 @@ class TestBatchedWideBaseline: """Testing the BatchedWideResNet baseline class.""" def test_batched(self): - net = BatchedWideResNet( - arch=18, + net = WideResNet( num_classes=10, - num_estimators=4, in_channels=3, - groups=1, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_wideresnet, + version="batched", imagenet_structure=False, + num_estimators=4, + groups=1, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--no-imagenet_structure"]) + summary(net) _ = net.criterion diff --git a/tests/baselines/test_masked.py b/tests/baselines/test_masked.py index ef86280e..f5210b5a 100644 --- a/tests/baselines/test_masked.py +++ b/tests/baselines/test_masked.py @@ -1,12 +1,10 @@ # fmt:off -from argparse import ArgumentParser - import pytest import torch import torch.nn as nn from torchinfo import summary -from torch_uncertainty.baselines.masked import MaskedResNet, MaskedWideResNet +from torch_uncertainty.baselines import ResNet, WideResNet from torch_uncertainty.optimization_procedures import ( optim_cifar10_wideresnet, optim_cifar100_resnet18, @@ -18,20 +16,19 @@ class TestMaskedBaseline: """Testing the MaskedResNet baseline class.""" def test_masked(self): - net = MaskedResNet( - arch=18, - in_channels=3, + net = ResNet( num_classes=10, - num_estimators=4, - scale=2, - groups=1, + in_channels=3, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar100_resnet18, + version="masked", + arch=18, imagenet_structure=False, + num_estimators=4, + scale=2, + groups=1, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--no-imagenet_structure"]) + summary(net) _ = net.criterion @@ -40,28 +37,32 @@ def test_masked(self): def test_masked_scale_lt_1(self): with pytest.raises(Exception): - _ = MaskedResNet( - arch=18, - in_channels=3, + _ = ResNet( num_classes=10, + in_channels=3, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar100_resnet18, + version="masked", + arch=18, + imagenet_structure=False, num_estimators=4, scale=0.5, groups=1, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, ) def test_masked_groups_lt_1(self): with pytest.raises(Exception): - _ = MaskedResNet( - arch=18, - in_channels=3, + _ = ResNet( num_classes=10, + in_channels=3, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar100_resnet18, + version="masked", + arch=18, + imagenet_structure=False, num_estimators=4, scale=2, groups=0, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar100_resnet18, ) @@ -69,20 +70,18 @@ class TestMaskedWideBaseline: """Testing the MaskedWideResNet baseline class.""" def test_masked(self): - net = MaskedWideResNet( - arch=18, - in_channels=3, + net = WideResNet( num_classes=10, - num_estimators=4, - scale=2, - groups=1, + in_channels=3, loss=nn.CrossEntropyLoss, 
optimization_procedure=optim_cifar10_wideresnet, + version="masked", imagenet_structure=False, + num_estimators=4, + scale=2, + groups=1, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--no-imagenet_structure"]) + summary(net) _ = net.criterion diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index 85909223..f4f843cd 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -1,12 +1,10 @@ # fmt:off -from argparse import ArgumentParser - import pytest import torch import torch.nn as nn from torchinfo import summary -from torch_uncertainty.baselines.packed import PackedResNet, PackedWideResNet +from torch_uncertainty.baselines import ResNet, WideResNet from torch_uncertainty.optimization_procedures import ( optim_cifar10_resnet50, optim_cifar10_wideresnet, @@ -18,20 +16,19 @@ class TestPackedBaseline: """Testing the PackedResNet baseline class.""" def test_packed(self): - net = PackedResNet( + net = ResNet( num_classes=10, - num_estimators=4, in_channels=3, - alpha=2, - gamma=1, - arch=50, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_resnet50, + version="packed", + arch=50, imagenet_structure=False, + num_estimators=4, + alpha=2, + gamma=1, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--arch", "50", "--no-imagenet_structure"]) + summary(net) _ = net.criterion @@ -40,28 +37,32 @@ def test_packed(self): def test_packed_alpha_lt_0(self): with pytest.raises(Exception): - _ = PackedResNet( + _ = ResNet( num_classes=10, - num_estimators=4, in_channels=3, - alpha=0, - gamma=1, - arch=50, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_resnet50, + version="packed", + arch=50, + imagenet_structure=False, + num_estimators=4, + alpha=0, + gamma=1, ) def test_packed_gamma_lt_1(self): with pytest.raises(Exception): - _ = PackedResNet( + _ = ResNet( num_classes=10, - num_estimators=4, in_channels=3, - alpha=2, - gamma=0, - arch=50, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_resnet50, + version="packed", + arch=50, + imagenet_structure=False, + num_estimators=4, + alpha=2, + gamma=0, ) @@ -69,19 +70,18 @@ class TestPackedWideBaseline: """Testing the PackedWideResNet baseline class.""" def test_packed(self): - net = PackedWideResNet( + net = WideResNet( num_classes=10, - num_estimators=4, in_channels=3, - alpha=2, - gamma=1, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_wideresnet, + version="packed", imagenet_structure=False, + num_estimators=4, + alpha=2, + gamma=1, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--no-imagenet_structure"]) + summary(net) _ = net.criterion diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index 4b7563cd..34546f82 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -1,12 +1,9 @@ # fmt:off - -from argparse import ArgumentParser - import torch import torch.nn as nn from torchinfo import summary -from torch_uncertainty.baselines.standard import ResNet, WideResNet +from torch_uncertainty.baselines import ResNet, WideResNet from torch_uncertainty.optimization_procedures import ( optim_cifar10_resnet18, optim_cifar10_wideresnet, @@ -21,15 +18,13 @@ def test_standard(self): net = ResNet( num_classes=10, in_channels=3, - groups=1, - arch=34, loss=nn.CrossEntropyLoss, 
optimization_procedure=optim_cifar10_resnet18, + version="vanilla", + arch=18, imagenet_structure=False, + groups=1, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--no-imagenet_structure"]) summary(net) _ = net.criterion @@ -40,18 +35,19 @@ def test_standard(self): class TestStandardWideBaseline: """Testing the WideResNet baseline class.""" - def test_packed(self): + def test_standard(self): net = WideResNet( num_classes=10, in_channels=3, - groups=1, loss=nn.CrossEntropyLoss, optimization_procedure=optim_cifar10_wideresnet, + version="vanilla", imagenet_structure=False, + groups=1, ) - parser = ArgumentParser("torch-uncertainty-test") - parser = net.add_model_specific_args(parser) - parser.parse_args(["--no-imagenet_structure"]) + # parser = ArgumentParser("torch-uncertainty-test") + # parser = net.add_model_specific_args(parser) + # parser.parse_args(["--no-imagenet_structure"]) summary(net) _ = net.criterion diff --git a/tests/test_cli.py b/tests/test_cli.py index e5c72e52..0daf9bc0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,67 +1,38 @@ # fmt: off -from argparse import ArgumentParser from pathlib import Path -import pytorch_lightning as pl import torch.nn as nn from cli_test_helpers import ArgvContext -from torch_uncertainty import cli_main, main -from torch_uncertainty.baselines.standard import ResNet +from torch_uncertainty import cls_main, init_args +from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import CIFAR10DataModule from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 -from ._dummies import Dummy, DummyDataModule - # fmt: on class TestCLI: """Testing the CLI function.""" - def test_main_summary(self): - root = Path(__file__).parent.absolute().parents[0] - - parser = ArgumentParser("torch-uncertainty") - parser.add_argument("--seed", type=int, default=None) - parser.add_argument("--test", type=int, default=None) - parser.add_argument("--summary", dest="summary", action="store_true") - parser.add_argument( - "--log_graph", dest="log_graph", action="store_true" - ) - parser.add_argument( - "--channels_last", - action="store_true", - help="Use channels last memory format", - ) - - datamodule = CIFAR10DataModule - network = ResNet - parser = pl.Trainer.add_argparse_args(parser) - parser = datamodule.add_argparse_args(parser) - parser = network.add_model_specific_args(parser) - - # Simulate that summary is True & the only argument - args = parser.parse_args(["--no-imagenet_structure"]) - args.summary = True - - main( - network, - datamodule, - nn.CrossEntropyLoss, - optim_cifar10_resnet18, - root, - "std", - args, - ) - - def test_cli_main(self): + def test_cls_main_summary(self): root = Path(__file__).parent.absolute().parents[0] with ArgvContext(""): - cli_main( - Dummy, - DummyDataModule, - nn.CrossEntropyLoss, - optim_cifar10_resnet18, - root, - "dummy", + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + model = ResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + imagenet_structure=False, + **vars(args), ) + + cls_main(model, dm, root, "std", args) From 1d1db24fd1a27293fdb982efa2ca3a0888853aac Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 29 May 2023 18:16:33 +0200 
Subject: [PATCH 07/23] Update API reference :books: --- docs/source/api.rst | 38 ++------------------------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 0fdf8108..8bd25dd2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -10,8 +10,8 @@ This API provides lightning-based models that can be easily trained and evaluate .. currentmodule:: torch_uncertainty.baselines -Vanilla -^^^^^^^ +Classification +^^^^^^^^^^^^^^ .. autosummary:: :toctree: generated/ @@ -21,40 +21,6 @@ Vanilla ResNet WideResNet -Packed-Ensembles -^^^^^^^^^^^^^^^^ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: class.rst - - PackedResNet - PackedWideResNet - -Masksembles -^^^^^^^^^^^ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: class.rst - - MaskedResNet - MaskedWideResNet - -BatchEnsemble -^^^^^^^^^^^^^ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: class.rst - - - BatchedResNet - BatchedWideResNet - Models ------ From a7c8021cbc6c1f3c7d7f49a80f31ff1860ebe679 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 11:28:51 +0200 Subject: [PATCH 08/23] Add support for BatchEnsemble :sparkles: --- experiments/classification/cifar10/resnet.py | 6 +- experiments/classification/cifar100/resnet.py | 6 +- torch_uncertainty/optimization_procedures.py | 93 +++++++++++++++---- 3 files changed, 84 insertions(+), 21 deletions(-) diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index bd9df801..db1ee687 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -10,7 +10,7 @@ # fmt: on if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] + root = Path(__file__).parent.absolute().parents[2] args = init_args(ResNet, CIFAR10DataModule) @@ -25,7 +25,9 @@ num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure(f"resnet{args.arch}", "cifar10"), + optimization_procedure=get_procedure( + f"resnet{args.arch}", "cifar10", args.version + ), imagenet_structure=False, **vars(args), ) diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py index de861a93..0f780638 100644 --- a/experiments/classification/cifar100/resnet.py +++ b/experiments/classification/cifar100/resnet.py @@ -10,7 +10,7 @@ # fmt: on if __name__ == "__main__": - root = Path(__file__).parent.absolute().parents[1] + root = Path(__file__).parent.absolute().parents[2] args = init_args(ResNet, CIFAR100DataModule) @@ -25,7 +25,9 @@ num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure(f"resnet{args.arch}", "cifar100"), + optimization_procedure=get_procedure( + f"resnet{args.arch}", "cifar100", args.version + ), imagenet_structure=False, **vars(args), ) diff --git a/torch_uncertainty/optimization_procedures.py b/torch_uncertainty/optimization_procedures.py index 86da5bb2..f9adf650 100644 --- a/torch_uncertainty/optimization_procedures.py +++ b/torch_uncertainty/optimization_procedures.py @@ -1,4 +1,7 @@ # fmt: off +from functools import partial +from typing import Callable + import torch.nn as nn import torch.optim as optim from timm.optim import Lamb @@ -83,7 +86,7 @@ def optim_cifar100_resnet18(model: nn.Module) -> dict: return {"optimizer": optimizer, "lr_scheduler": scheduler} -def 
optim_cifar100_resnet50(model: nn.Module, adam: bool = False) -> dict: +def optim_cifar100_resnet50(model: nn.Module) -> dict: r"""Hyperparameters from Deep Residual Learning for Image Recognition https://arxiv.org/pdf/1512.03385.pdf """ @@ -104,7 +107,7 @@ def optim_cifar100_resnet50(model: nn.Module, adam: bool = False) -> dict: def optim_imagenet_resnet50( model: nn.Module, - n_epochs: int = 90, + num_epochs: int = 90, start_lr: float = 0.256, end_lr: float = 0, ) -> dict: @@ -119,7 +122,7 @@ def optim_imagenet_resnet50( nesterov=False, ) scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, n_epochs, eta_min=end_lr + optimizer, num_epochs, eta_min=end_lr ) return { "optimizer": optimizer, @@ -176,17 +179,73 @@ def optim_imagenet_resnet50_A3( } -def get_procedure(model_name, dm_name): - if model_name == "resnet18": - if dm_name == "cifar10": - return optim_cifar10_resnet18 - elif dm_name == "cifar100": - return optim_cifar100_resnet18 - elif model_name == "resnet50": - if dm_name == "cifar10": - return optim_cifar10_resnet50 - elif dm_name == "cifar100": - return optim_cifar100_resnet50 - elif model_name == "wideresnet28x10": - if dm_name == "cifar10" or dm_name == "cifar100": - return optim_cifar10_wideresnet +def batch_ensemble_wrapper(model: nn.Module, optimization_procedure: Callable): + procedure = optimization_procedure(model) + param_optimizer = procedure["optimizer"] + scheduler = procedure["lr_scheduler"] + + weight_decay = param_optimizer.defaults["weight_decay"] + lr = param_optimizer.defaults["lr"] + momentum = param_optimizer.defaults["momentum"] + + name_list = ["R", "S"] + params_multi_tmp = list( + filter( + lambda kv: (name_list[0] in kv[0]) or (name_list[1] in kv[0]), + model.named_parameters(), + ) + ) + param_core_tmp = list( + filter( + lambda kv: (name_list[0] not in kv[0]) + and (name_list[1] not in kv[0]), + model.named_parameters(), + ) + ) + + params_multi = [param for _, param in params_multi_tmp] + param_core = [param for _, param in param_core_tmp] + optimizer = optim.SGD( + [ + {"params": param_core, "weight_decay": weight_decay}, + {"params": params_multi, "weight_decay": 0.0}, + ], + lr=lr, + momentum=momentum, + ) + + scheduler.optimizer = optimizer + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +def get_procedure( + arch_name: str, ds_name: str, model_name: str = "" +) -> Callable: + """Get the optimization procedure for a given architecture and dataset. + + Args: + arch_name (str): The name of the architecture. + ds_name (str): The name of the dataset. + model_name (str, optional): The name of the model. Defaults to "". + + Returns: + callable: The optimization procedure. 
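A minimal usage sketch, mirroring the calls added to the unit tests later in this series:

```python
from torch_uncertainty.models.resnet import resnet18
from torch_uncertainty.optimization_procedures import get_procedure

model = resnet18(in_channels=3, num_classes=10)
procedure = get_procedure("resnet18", "cifar10")  # -> optim_cifar10_resnet18
config = procedure(model)  # {"optimizer": ..., "lr_scheduler": ...}
```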
+ """ + if arch_name == "resnet18": + if ds_name == "cifar10": + procedure = optim_cifar10_resnet18 + elif ds_name == "cifar100": + procedure = optim_cifar100_resnet18 + elif arch_name == "resnet50": + if ds_name == "cifar10": + procedure = optim_cifar10_resnet50 + elif ds_name == "cifar100": + procedure = optim_cifar100_resnet50 + elif arch_name == "wideresnet28x10": + if ds_name == "cifar10" or ds_name == "cifar100": + procedure = optim_cifar10_wideresnet + + if model_name == "batch_ensemble": + procedure = partial(batch_ensemble_wrapper, {"procedure": procedure}) + + return procedure From c80d2056df9bbf700a47a424f2b5896f34e45cbf Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 11:29:50 +0200 Subject: [PATCH 09/23] Add BatchEnsemble & TempScaling to Readme :book: Fix md violations :shirt: --- README.md | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 1ab4bf97..635efaf9 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ _TorchUncertainty_ is a package designed to help you leverage uncertainty quanti --- This package provides a multi-level API, including: + - ready-to-train baselines on research datasets, such as ImageNet and CIFAR - baselines available for training on your datasets - [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (work in progress 🚧). @@ -38,16 +39,24 @@ Please find the documentation at [torch-uncertainty.github.io](https://torch-unc A quickstart is available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). -## Implemented baselines +## Implemented methods + +### Baselines To date, the following baselines are implemented: - Deep Ensembles +- BatchEnsemble - Masksembles - Packed-Ensembles -## Tutorials +### Post-processing methods + +To date, the following post-processing methods are implemented: +- Temperature scaling + +## Tutorials ## Awesome Uncertainty repositories @@ -58,10 +67,12 @@ You may find a lot of information about modern uncertainty estimation techniques This package also contains the official implementation of Packed-Ensembles.
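With the unified baseline introduced earlier in this series, a Packed-Ensembles ResNet is built as sketched below (mirroring the updated unit tests; the hyperparameter values are illustrative, not recommendations):

```python
import torch.nn as nn

from torch_uncertainty.baselines import ResNet
from torch_uncertainty.optimization_procedures import optim_cifar10_resnet50

net = ResNet(
    num_classes=10,
    in_channels=3,
    loss=nn.CrossEntropyLoss,
    optimization_procedure=optim_cifar10_resnet50,
    version="packed",  # also: "vanilla", "batched", "masked"
    arch=50,
    imagenet_structure=False,
    num_estimators=4,
    alpha=2,
    gamma=1,
)
```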
If you find the corresponding models interesting, please consider citing our [paper](https://arxiv.org/abs/2210.09184): - - @inproceedings{laurent2023packed, - title={Packed-Ensembles for Efficient Uncertainty Estimation}, - author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni}, - booktitle={ICLR}, - year={2023} - } + +```text +@inproceedings{laurent2023packed, + title={Packed-Ensembles for Efficient Uncertainty Estimation}, + author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni}, + booktitle={ICLR}, + year={2023} +} +``` From f969be20c87b2744b0a79cf84681348b62f53d6c Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 11:35:18 +0200 Subject: [PATCH 10/23] Add wideresnet experiments :sparkles: --- .../classification/cifar10/wideresnet.py | 35 +++++++++++++++++++ .../classification/cifar100/wideresnet.py | 35 +++++++++++++++++++ experiments/classification/readme.md | 2 +- 3 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 experiments/classification/cifar10/wideresnet.py create mode 100644 experiments/classification/cifar100/wideresnet.py diff --git a/experiments/classification/cifar10/wideresnet.py b/experiments/classification/cifar10/wideresnet.py new file mode 100644 index 00000000..367a54dd --- /dev/null +++ b/experiments/classification/cifar10/wideresnet.py @@ -0,0 +1,35 @@ +# fmt: off +from pathlib import Path + +import torch.nn as nn + +from torch_uncertainty import cls_main, init_args +from torch_uncertainty.baselines import WideResNet +from torch_uncertainty.datamodules import CIFAR10DataModule +from torch_uncertainty.optimization_procedures import get_procedure + +# fmt: on +if __name__ == "__main__": + root = Path(__file__).parent.absolute().parents[2] + + args = init_args(WideResNet, CIFAR10DataModule) + + net_name = f"{args.version}-wideresnet{args.arch}-cifar10" + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # model + model = WideResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=get_procedure( + f"resnet{args.arch}", "cifar10", args.version + ), + imagenet_structure=False, + **vars(args), + ) + + cls_main(model, dm, root, net_name, args) diff --git a/experiments/classification/cifar100/wideresnet.py b/experiments/classification/cifar100/wideresnet.py new file mode 100644 index 00000000..07c54aa7 --- /dev/null +++ b/experiments/classification/cifar100/wideresnet.py @@ -0,0 +1,35 @@ +# fmt: off +from pathlib import Path + +import torch.nn as nn + +from torch_uncertainty import cls_main, init_args +from torch_uncertainty.baselines import WideResNet +from torch_uncertainty.datamodules import CIFAR100DataModule +from torch_uncertainty.optimization_procedures import get_procedure + +# fmt: on +if __name__ == "__main__": + root = Path(__file__).parent.absolute().parents[2] + + args = init_args(WideResNet, CIFAR100DataModule) + + net_name = f"{args.version}-wideresnet{args.arch}-cifar100" + + # datamodule + args.root = str(root / "data") + dm = CIFAR100DataModule(**vars(args)) + + # model + model = WideResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=get_procedure( + f"resnet{args.arch}", "cifar100", args.version + ), + imagenet_structure=False, + **vars(args), + ) + +
cls_main(model, dm, root, net_name, args) diff --git a/experiments/classification/readme.md b/experiments/classification/readme.md index 8bbc67aa..097aeec2 100644 --- a/experiments/classification/readme.md +++ b/experiments/classification/readme.md @@ -12,4 +12,4 @@ ### CIFAR-100 * ResNet -* WideResNet \ No newline at end of file +* WideResNet From 981d7eaeb817e176724e4cce8c2e8952876f59f5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 12:38:55 +0200 Subject: [PATCH 11/23] Fix BatchEnsembles optimizer :bug: --- torch_uncertainty/optimization_procedures.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/optimization_procedures.py b/torch_uncertainty/optimization_procedures.py index f9adf650..dcc8cbe6 100644 --- a/torch_uncertainty/optimization_procedures.py +++ b/torch_uncertainty/optimization_procedures.py @@ -245,7 +245,9 @@ def get_procedure( if ds_name == "cifar10" or ds_name == "cifar100": procedure = optim_cifar10_wideresnet - if model_name == "batch_ensemble": - procedure = partial(batch_ensemble_wrapper, {"procedure": procedure}) + if model_name == "batched": + procedure = partial( + batch_ensemble_wrapper, optimization_procedure=procedure + ) return procedure From 3dfbe5c0bf8f2bfa4b1516507796bd01c26caf37 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 13:17:55 +0200 Subject: [PATCH 12/23] Fix experiment name in cifar100 :bug: --- experiments/classification/cifar100/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py index 0f780638..bfb96b6b 100644 --- a/experiments/classification/cifar100/resnet.py +++ b/experiments/classification/cifar100/resnet.py @@ -14,7 +14,7 @@ args = init_args(ResNet, CIFAR100DataModule) - net_name = f"{args.version}-resnet{args.arch}-cifar10" + net_name = f"{args.version}-resnet{args.arch}-cifar100" # datamodule args.root = str(root / "data") From a9d3c26ffd560bf3d1f9ae2557d46070dd03bd56 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 17:47:00 +0200 Subject: [PATCH 13/23] Revert del. of PL override of None num_epochs => 1k epochs :hammer: --- torch_uncertainty/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 5ce486d4..9dfbd5c7 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -50,6 +50,13 @@ def cls_main( if isinstance(root, str): root = Path(root) + if args.max_epochs is None: + print( + "Setting max_epochs to 1 for testing purposes. Set max_epochs " + "manually to train the model." + ) + args.max_epochs = 1 + if isinstance(args.seed, int): pl.seed_everything(args.seed) From e9ae77719589d6a676207e7ecc97fadc8fdaf931 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 23:09:02 +0200 Subject: [PATCH 14/23] Use get_procedure in opt. proc. 
tests :heavy_check_mark: --- tests/test_optimization_procedures.py | 30 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_optimization_procedures.py b/tests/test_optimization_procedures.py index b6cdf630..fade93bb 100644 --- a/tests/test_optimization_procedures.py +++ b/tests/test_optimization_procedures.py @@ -1,36 +1,42 @@ # flake8: noqa # fmt: off from torch_uncertainty.models.resnet import resnet18, resnet50 - -# from torch_uncertainty.models.wideresnet import resnet18, resnet50 +from torch_uncertainty.models.wideresnet import wideresnet28x10 from torch_uncertainty.optimization_procedures import * +from torch_uncertainty.optimization_procedures import get_procedure # fmt: on class TestOptProcedures: def test_optim_cifar10_resnet18(self): + procedure = get_procedure("resnet18", "cifar10", "standard") model = resnet18(in_channels=3, num_classes=10) - optim_cifar10_resnet18(model) + procedure(model) def test_optim_cifar10_resnet50(self): + procedure = get_procedure("resnet50", "cifar10", "packed") model = resnet50(in_channels=3, num_classes=10) - optim_cifar10_resnet50(model) + procedure(model) - # def test_optim_cifar10_wideresnet(self): - # model = resnet50() - # optim_cifar10_wideresnet(model) + def test_optim_cifar10_wideresnet(self): + procedure = get_procedure("wideresnet28x10", "cifar100", "batched") + model = wideresnet28x10(in_channels=3, num_classes=10) + procedure(model) def test_optim_cifar100_resnet18(self): + procedure = get_procedure("resnet50", "cifar100", "masked") model = resnet18(in_channels=3, num_classes=100) - optim_cifar100_resnet18(model) + procedure(model) def test_optim_cifar100_resnet50(self): + procedure = get_procedure("resnet50", "cifar100") model = resnet50(in_channels=3, num_classes=100) - optim_cifar100_resnet50(model) + procedure(model) - # def test_optim_cifar100_wideresnet(self): - # model = resnet50() - # optim_cifar10_wideresnet(model) + def test_optim_cifar100_wideresnet(self): + procedure = get_procedure("wideresnet28x10", "cifar100") + model = wideresnet28x10(in_channels=3, num_classes=100) + procedure(model) def test_optim_imagenet_resnet50(self): model = resnet50(in_channels=3, num_classes=1000) From 2147bcad952d32105428075d9c08828b1a61251e Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 30 May 2023 23:14:56 +0200 Subject: [PATCH 15/23] Second CLI test with different arguments :heavy_check_mark: --- tests/test_cli.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 0daf9bc0..7e22b9da 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -36,3 +36,26 @@ def test_cls_main_summary(self): ) cls_main(model, dm, root, "std", args) + + def test_cls_main_other_arguments(self): + root = Path(__file__).parent.absolute().parents[0] + with ArgvContext("--seed 42 --max_epochs 1 --channels_last"): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = root / "data" + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + model = ResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + imagenet_structure=False, + **vars(args), + ) + + cls_main(model, dm, root, "std", args) From 5236c70549b7153fe540aa92d0539e7fc3144349 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 31 May 2023 12:32:49 +0200 Subject: [PATCH 16/23] Factorize OOD criterion arguments :hammer: --- 
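This factorization assumes that ``ClassificationEnsemble.add_model_specific_args`` now defines the shared OOD-criterion flags (``--entropy``, ``--logits``, ``--mutual_information``, ``--variation_ratio``) — which the parser-chaining call added below suggests — so every baseline inherits them. A hedged sketch of the resulting CLI behavior:

```python
from argparse import ArgumentParser

from torch_uncertainty.baselines import ResNet

parser = ArgumentParser("torch-uncertainty")
# Chains the shared ClassificationEnsemble flags before the ResNet-specific ones.
parser = ResNet.add_model_specific_args(parser)
args = parser.parse_args(["--version", "packed", "--arch", "50", "--entropy"])
assert args.use_entropy  # dest of --entropy, per the flags factored out below
```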
.../baselines/classification/resnet.py | 36 +------------------ .../baselines/classification/wideresnet.py | 29 +-------------- 2 files changed, 2 insertions(+), 63 deletions(-) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 1a7d2980..b1dd8e77 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -163,6 +163,7 @@ def __new__( @classmethod def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser = ClassificationEnsemble.add_model_specific_args(parser) parser.add_argument( "--version", type=str, @@ -177,12 +178,6 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: default=18, help=f"Architecture of ResNet. Choose among: {cls.archs}", ) - parser.add_argument( - "--num_estimators", - type=int, - default=None, - help="Number of estimators for ensemble", - ) parser.add_argument( "--groups", type=int, @@ -207,39 +202,10 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: default=None, help="Gamma for packed resnet", ) - # FIXME: should be a str to choose among the available OOD criteria - # rather than a boolean, but it is not possible since - # ClassificationSingle and ClassificationEnsemble have different OOD - # criteria. - parser.add_argument( - "--entropy", - dest="use_entropy", - action=BooleanOptionalAction, - default=False, - ) - parser.add_argument( - "--logits", - dest="use_logits", - action=BooleanOptionalAction, - default=False, - ) - parser.add_argument( - "--mutual_information", - dest="use_mi", - action=BooleanOptionalAction, - default=False, - ) - parser.add_argument( - "--variation_ratio", - dest="use_variation_ratio", - action=BooleanOptionalAction, - default=False, - ) parser.add_argument( "--pretrained", dest="pretrained", action=BooleanOptionalAction, default=False, ) - return parser diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index d22a5065..e5acd6ce 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -128,6 +128,7 @@ def __new__( @classmethod def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser = ClassificationEnsemble.add_model_specific_args(parser) parser.add_argument( "--version", type=str, @@ -166,34 +167,6 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: default=None, help="Gamma for packed wideresnet", ) - # FIXME: should be a str to choose among the available OOD criteria - # rather than a boolean, but it is not possible since - # ClassificationSingle and ClassificationEnsemble have different OOD - # criteria. 
- parser.add_argument( - "--entropy", - dest="use_entropy", - action=BooleanOptionalAction, - default=False, - ) - parser.add_argument( - "--logits", - dest="use_logits", - action=BooleanOptionalAction, - default=False, - ) - parser.add_argument( - "--mutual_information", - dest="use_mi", - action=BooleanOptionalAction, - default=False, - ) - parser.add_argument( - "--variation_ratio", - dest="use_variation_ratio", - action=BooleanOptionalAction, - default=False, - ) parser.add_argument( "--pretrained", dest="pretrained", From 2b3c16a8301388d8744256510e4dd915f8485790 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 31 May 2023 12:35:09 +0200 Subject: [PATCH 17/23] Fix num_estimator duplicate :bug: --- torch_uncertainty/baselines/classification/wideresnet.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index e5acd6ce..a2e4ec11 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -137,12 +137,6 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: help="Variation of WideResNet. " + f"Choose among: {cls.versions.keys()}", ) - parser.add_argument( - "--num_estimators", - type=int, - default=None, - help="Number of estimators for ensemble", - ) parser.add_argument( "--groups", type=int, From cd06cd13b4f48fb98eba03a8d2733859575ae2b2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 31 May 2023 13:29:56 +0200 Subject: [PATCH 18/23] Add groups to all networks :sparkles: Fix tests :heavy_check_mark: --- tests/baselines/test_batched.py | 1 + tests/baselines/test_packed.py | 4 ++ .../baselines/classification/resnet.py | 16 +++++- .../baselines/classification/wideresnet.py | 16 +++++- torch_uncertainty/layers/batchens_layers.py | 1 - .../layers/masksembles_layers.py | 4 +- torch_uncertainty/models/resnet/batched.py | 52 ++++++++++++++++--- torch_uncertainty/models/resnet/masked.py | 8 +-- torch_uncertainty/models/resnet/packed.py | 33 ++++++++++-- .../models/wideresnet/batched.py | 28 +++++++--- torch_uncertainty/models/wideresnet/packed.py | 15 +++++- torch_uncertainty/routines/classification.py | 21 +++++++- 12 files changed, 168 insertions(+), 31 deletions(-) diff --git a/tests/baselines/test_batched.py b/tests/baselines/test_batched.py index 8d550fc5..a5036660 100644 --- a/tests/baselines/test_batched.py +++ b/tests/baselines/test_batched.py @@ -24,6 +24,7 @@ def test_batched(self): arch=18, imagenet_structure=False, num_estimators=4, + groups=1, ) summary(net) diff --git a/tests/baselines/test_packed.py b/tests/baselines/test_packed.py index f4f843cd..759e3246 100644 --- a/tests/baselines/test_packed.py +++ b/tests/baselines/test_packed.py @@ -27,6 +27,7 @@ def test_packed(self): num_estimators=4, alpha=2, gamma=1, + groups=1, ) summary(net) @@ -48,6 +49,7 @@ def test_packed_alpha_lt_0(self): num_estimators=4, alpha=0, gamma=1, + groups=1, ) def test_packed_gamma_lt_1(self): @@ -63,6 +65,7 @@ def test_packed_gamma_lt_1(self): num_estimators=4, alpha=2, gamma=0, + groups=1, ) @@ -80,6 +83,7 @@ def test_packed(self): num_estimators=4, alpha=2, gamma=1, + groups=1, ) summary(net) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index b1dd8e77..29c592e1 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -107,16 +107,30 @@ 
def __new__( raise ValueError( f"Attribute `gamma` should be >= 1, not {gamma}" ) + if groups < 1: + raise ValueError( + f"Number of groups must be at least 1, not {groups}" + ) params.update( { "num_estimators": num_estimators, "alpha": alpha, "gamma": gamma, + "groups": groups, "pretrained": pretrained, } ) elif version == "batched": - params.update({"num_estimators": num_estimators}) + if groups < 1: + raise ValueError( + f"Number of groups must be at least 1, not {groups}" + ) + params.update( + { + "num_estimators": num_estimators, + "groups": groups, + } + ) elif version == "masked": # TODO: check parameters within a function if scale < 1: diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index a2e4ec11..4f83b1f6 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -72,16 +72,30 @@ def __new__( raise ValueError( f"Attribute `gamma` should be >= 1, not {gamma}" ) + if groups < 1: + raise ValueError( + f"Number of groups must be at least 1, not {groups}" + ) params.update( { "num_estimators": num_estimators, "alpha": alpha, "gamma": gamma, + "groups": groups, # "pretrained": pretrained, } ) elif version == "batched": - params.update({"num_estimators": num_estimators}) + if groups < 1: + raise ValueError( + f"Number of groups must be at least 1, not {groups}" + ) + params.update( + { + "num_estimators": num_estimators, + "groups": groups, + } + ) elif version == "masked": # TODO: check parameters within a function if scale < 1: diff --git a/torch_uncertainty/layers/batchens_layers.py b/torch_uncertainty/layers/batchens_layers.py index 526f6b5b..751b327d 100644 --- a/torch_uncertainty/layers/batchens_layers.py +++ b/torch_uncertainty/layers/batchens_layers.py @@ -296,7 +296,6 @@ def __init__( self.stride = _pair(stride) self.padding = padding if isinstance(padding, str) else _pair(padding) self.dilation = _pair(dilation) - self.groups = groups self.conv = nn.Conv2d( in_channels=in_channels, diff --git a/torch_uncertainty/layers/masksembles_layers.py b/torch_uncertainty/layers/masksembles_layers.py index eaf49373..a087b9bc 100644 --- a/torch_uncertainty/layers/masksembles_layers.py +++ b/torch_uncertainty/layers/masksembles_layers.py @@ -206,7 +206,7 @@ def __init__( self.mask = Mask1D( in_features, num_masks=num_estimators, scale=scale, **factory_kwargs ) - self.conv1x1 = nn.Linear( + self.linear = nn.Linear( in_features=in_features, out_features=out_features, bias=bias, @@ -214,7 +214,7 @@ def __init__( ) def forward(self, input: Tensor) -> Tensor: - return self.conv1x1(self.mask(input)) + return self.linear(self.mask(input)) class MaskedConv2d(nn.Module): diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index 4a0d41f6..d369b949 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -5,6 +5,8 @@ Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 """ # fmt: off +from typing import List, Type, Union + import torch.nn as nn import torch.nn.functional as F from torch import Tensor @@ -30,6 +32,7 @@ def __init__( planes: int, stride: int = 1, num_estimators: int = 4, + groups: int = 1, ) -> None: super().__init__() self.conv1 = BatchConv2d( @@ -37,6 +40,7 @@ def __init__( planes, kernel_size=3, num_estimators=num_estimators, + groups=groups, stride=stride, padding=1, bias=False, @@ -47,6 +51,7 @@ def __init__( planes, kernel_size=3, num_estimators=num_estimators, + groups=groups, stride=1, padding=1, bias=False, @@ -59,6 +64,7 @@ def __init__( nn.Conv2d( in_planes, self.expansion * planes, + groups=groups, kernel_size=1, stride=stride, bias=False, @@ -83,6 +89,7 @@ def __init__( planes: int, stride: int = 1, num_estimators: int = 4, + groups: int = 1, ) -> None: super(Bottleneck, self).__init__() self.conv1 = BatchConv2d( @@ -90,6 +97,7 @@ def __init__( planes, kernel_size=1, num_estimators=num_estimators, + groups=groups, bias=False, ) self.bn1 = nn.BatchNorm2d(planes) @@ -98,6 +106,7 @@ def __init__( planes, kernel_size=3, num_estimators=num_estimators, + groups=groups, stride=stride, padding=1, bias=False, @@ -107,6 +116,7 @@ def __init__( planes, self.expansion * planes, num_estimators=num_estimators, + groups=groups, kernel_size=1, bias=False, ) @@ -120,6 +130,7 @@ def __init__( self.expansion * planes, kernel_size=1, num_estimators=num_estimators, + groups=groups, stride=stride, bias=False, ), @@ -138,10 +149,11 @@ def forward(self, input: Tensor) -> Tensor: class _BatchedResNet(nn.Module): def __init__( self, - block, - num_blocks, + block: Type[Union[BasicBlock, Bottleneck]], + num_blocks: List[int], in_channels: int, - num_estimators, + num_estimators: int, + groups: int = 1, num_classes=10, width_multiplier: int = 1, imagenet_structure: bool = True, @@ -158,8 +170,9 @@ def __init__( kernel_size=7, stride=2, padding=3, - bias=False, num_estimators=num_estimators, + groups=groups, + bias=False, ) else: self.conv1 = BatchConv2d( @@ -168,8 +181,9 @@ def __init__( kernel_size=3, stride=1, padding=1, - bias=False, num_estimators=num_estimators, + groups=groups, + bias=False, ) self.bn1 = nn.BatchNorm2d(64 * self.width_multiplier) @@ -186,6 +200,7 @@ def __init__( num_blocks[0], stride=1, num_estimators=num_estimators, + groups=groups, ) self.layer2 = self._make_layer( block, @@ -193,6 +208,7 @@ def __init__( num_blocks[1], stride=2, num_estimators=num_estimators, + groups=groups, ) self.layer3 = self._make_layer( block, @@ -200,6 +216,7 @@ def __init__( num_blocks[2], stride=2, num_estimators=num_estimators, + groups=groups, ) self.layer4 = self._make_layer( block, @@ -207,6 +224,7 @@ def __init__( num_blocks[3], stride=2, num_estimators=num_estimators, + groups=groups, ) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -217,11 +235,21 @@ def __init__( num_estimators=num_estimators, ) - def _make_layer(self, block, planes, num_blocks, stride, num_estimators): + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + num_blocks: int, + stride: int, + num_estimators: int, + groups: int, + ): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: - layers.append(block(self.in_planes, planes, stride, num_estimators)) + layers.append( + block(self.in_planes, planes, stride, num_estimators, groups) + ) self.in_planes = planes * block.expansion return nn.Sequential(*layers) @@ -242,6 +270,7 @@ def forward(self, x): def batched_resnet18( 
in_channels: int, num_estimators: int, + groups: int, num_classes: int, imagenet_structure: bool = True, ) -> _BatchedResNet: @@ -263,6 +292,7 @@ def batched_resnet18( in_channels=in_channels, num_estimators=num_estimators, num_classes=num_classes, + groups=groups, imagenet_structure=imagenet_structure, ) @@ -270,6 +300,7 @@ def batched_resnet18( def batched_resnet34( in_channels: int, num_estimators: int, + groups: int, num_classes: int, imagenet_structure: bool = True, ) -> _BatchedResNet: @@ -291,6 +322,7 @@ def batched_resnet34( in_channels=in_channels, num_estimators=num_estimators, num_classes=num_classes, + groups=groups, imagenet_structure=imagenet_structure, ) @@ -298,6 +330,7 @@ def batched_resnet34( def batched_resnet50( in_channels: int, num_estimators: int, + groups: int, num_classes: int, width_multiplier: int = 1, imagenet_structure: bool = True, @@ -321,6 +354,7 @@ def batched_resnet50( num_estimators=num_estimators, num_classes=num_classes, width_multiplier=width_multiplier, + groups=groups, imagenet_structure=imagenet_structure, ) @@ -328,6 +362,7 @@ def batched_resnet50( def batched_resnet101( in_channels: int, num_estimators: int, + groups: int, num_classes: int, imagenet_structure: bool = True, ) -> _BatchedResNet: @@ -349,6 +384,7 @@ def batched_resnet101( in_channels=in_channels, num_estimators=num_estimators, num_classes=num_classes, + groups=groups, imagenet_structure=imagenet_structure, ) @@ -356,6 +392,7 @@ def batched_resnet101( def batched_resnet152( in_channels: int, num_estimators: int, + groups: int, num_classes: int, imagenet_structure: bool = True, ) -> _BatchedResNet: @@ -379,5 +416,6 @@ def batched_resnet152( in_channels=in_channels, num_estimators=num_estimators, num_classes=num_classes, + groups=groups, imagenet_structure=imagenet_structure, ) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 88e971f1..105911e7 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -31,13 +31,13 @@ def __init__( ): super(BasicBlock, self).__init__() - # No subgroups for the first layer self.conv1 = MaskedConv2d( in_planes, planes, kernel_size=3, num_estimators=num_estimators, scale=scale, + groups=groups, stride=stride, padding=1, bias=False, @@ -94,13 +94,13 @@ def __init__( ): super(Bottleneck, self).__init__() - # No subgroups for the first layer self.conv1 = MaskedConv2d( in_planes, planes, kernel_size=1, num_estimators=num_estimators, scale=scale, + groups=groups, bias=False, ) self.bn1 = nn.BatchNorm2d(planes) @@ -182,7 +182,7 @@ def __init__( kernel_size=7, stride=2, padding=3, - groups=1, # No groups for the first layer + groups=groups, bias=False, ) else: @@ -192,7 +192,7 @@ def __init__( kernel_size=3, stride=1, padding=1, - groups=1, + groups=groups, bias=False, ) diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/resnet/packed.py index 37671303..b39ce471 100644 --- a/torch_uncertainty/models/resnet/packed.py +++ b/torch_uncertainty/models/resnet/packed.py @@ -61,6 +61,7 @@ def __init__( alpha: float = 2, num_estimators: int = 4, gamma: int = 1, + groups: int = 1, ): super(BasicBlock, self).__init__() @@ -71,6 +72,7 @@ def __init__( kernel_size=3, alpha=alpha, num_estimators=num_estimators, + groups=groups, stride=stride, padding=1, bias=False, @@ -83,6 +85,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, stride=1, padding=1, bias=False, @@ -99,6 +102,7 @@ def __init__( 
alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, stride=stride, bias=False, ), @@ -124,6 +128,7 @@ def __init__( alpha: float = 2, num_estimators: int = 4, gamma: int = 1, + groups: int = 1, ): super(Bottleneck, self).__init__() @@ -134,7 +139,8 @@ def __init__( kernel_size=1, alpha=alpha, num_estimators=num_estimators, - gamma=1, # No groups in the first layer + gamma=1, # No groups from gamma in the first layer + groups=groups, bias=False, ) self.bn1 = nn.BatchNorm2d(planes * alpha) @@ -147,7 +153,7 @@ def __init__( gamma=gamma, stride=stride, padding=1, - groups=1, + groups=groups, bias=False, ) self.bn2 = nn.BatchNorm2d(planes * alpha) @@ -158,6 +164,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, bias=False, ) self.bn3 = nn.BatchNorm2d(self.expansion * planes * alpha) @@ -172,6 +179,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, stride=stride, bias=False, ), @@ -197,6 +205,7 @@ def __init__( num_estimators: int, alpha: int = 2, gamma: int = 1, + groups: int = 1, imagenet_structure: bool = True, ) -> None: super().__init__() @@ -216,7 +225,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=1, # No groups for the first layer - groups=1, + groups=groups, bias=False, first=True, ) @@ -230,7 +239,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=1, # No groups for the first layer - groups=1, + groups=groups, bias=False, first=True, ) @@ -252,6 +261,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, ) self.layer2 = self._make_layer( block, @@ -261,6 +271,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, ) self.layer3 = self._make_layer( block, @@ -270,6 +281,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, ) self.layer4 = self._make_layer( block, @@ -279,6 +291,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, ) self.pool = nn.AdaptiveAvgPool2d(output_size=1) @@ -301,6 +314,7 @@ def _make_layer( alpha: float, num_estimators: int, gamma: int, + groups: int, ) -> nn.Module: strides = [stride] + [1] * (num_blocks - 1) layers = [] @@ -313,6 +327,7 @@ def _make_layer( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, ) ) self.in_planes = planes * block.expansion @@ -342,6 +357,7 @@ def packed_resnet18( alpha: int, gamma: int, num_classes: int, + groups: int, imagenet_structure: bool = True, pretrained: bool = False, ) -> _PackedResNet: @@ -365,6 +381,7 @@ def packed_resnet18( num_estimators=num_estimators, alpha=alpha, gamma=gamma, + groups=groups, num_classes=num_classes, imagenet_structure=imagenet_structure, ) @@ -382,6 +399,7 @@ def packed_resnet34( alpha: int, gamma: int, num_classes: int, + groups: int, imagenet_structure: bool = True, pretrained: bool = False, ) -> _PackedResNet: @@ -405,6 +423,7 @@ def packed_resnet34( num_estimators=num_estimators, alpha=alpha, gamma=gamma, + groups=groups, num_classes=num_classes, imagenet_structure=imagenet_structure, ) @@ -422,6 +441,7 @@ def packed_resnet50( alpha: int, gamma: int, num_classes: int, + groups: int, imagenet_structure: bool = True, pretrained: bool = False, ) -> _PackedResNet: @@ -445,6 +465,7 @@ def packed_resnet50( num_estimators=num_estimators, alpha=alpha, gamma=gamma, + groups=groups, num_classes=num_classes, imagenet_structure=imagenet_structure, ) @@ -462,6 +483,7 @@ def 
packed_resnet101( alpha: int, gamma: int, num_classes: int, + groups: int, imagenet_structure: bool = True, pretrained: bool = False, ) -> _PackedResNet: @@ -485,6 +507,7 @@ def packed_resnet101( num_estimators=num_estimators, alpha=alpha, gamma=gamma, + groups=groups, num_classes=num_classes, imagenet_structure=imagenet_structure, ) @@ -502,6 +525,7 @@ def packed_resnet152( alpha: int, gamma: int, num_classes: int, + groups: int, imagenet_structure: bool = True, pretrained: bool = False, ) -> _PackedResNet: @@ -527,6 +551,7 @@ def packed_resnet152( num_estimators=num_estimators, alpha=alpha, gamma=gamma, + groups=groups, num_classes=num_classes, imagenet_structure=imagenet_structure, ) diff --git a/torch_uncertainty/models/wideresnet/batched.py b/torch_uncertainty/models/wideresnet/batched.py index 42686bf8..f1e94431 100644 --- a/torch_uncertainty/models/wideresnet/batched.py +++ b/torch_uncertainty/models/wideresnet/batched.py @@ -18,6 +18,7 @@ def __init__( dropout_rate, stride=1, num_estimators=4, + groups: int = 1, ): super().__init__() self.bn1 = nn.BatchNorm2d(in_planes) @@ -26,6 +27,7 @@ def __init__( planes, kernel_size=3, num_estimators=num_estimators, + groups=groups, padding=1, bias=False, ) @@ -36,6 +38,7 @@ def __init__( planes, kernel_size=3, num_estimators=num_estimators, + groups=groups, stride=stride, padding=1, bias=False, @@ -48,6 +51,7 @@ def __init__( planes, kernel_size=1, num_estimators=num_estimators, + groups=groups, stride=stride, bias=True, ), @@ -68,9 +72,10 @@ def __init__( in_channels: int, num_classes: int, num_estimators: int, + groups: int = 1, dropout_rate: float = 0.0, imagenet_structure: bool = True, - ): + ) -> None: super().__init__() self.num_estimators = num_estimators self.in_planes = 16 @@ -86,22 +91,22 @@ def __init__( in_channels, nStages[0], num_estimators=self.num_estimators, + groups=groups, kernel_size=7, stride=2, padding=3, bias=True, - groups=1, ) else: self.conv1 = BatchConv2d( in_channels, nStages[0], num_estimators=self.num_estimators, + groups=groups, kernel_size=3, stride=1, padding=1, bias=True, - groups=1, ) if imagenet_structure: @@ -118,6 +123,7 @@ def __init__( dropout_rate, stride=1, num_estimators=self.num_estimators, + groups=groups, ) self.layer2 = self._wide_layer( WideBasicBlock, @@ -126,6 +132,7 @@ def __init__( dropout_rate, stride=2, num_estimators=self.num_estimators, + groups=groups, ) self.layer3 = self._wide_layer( WideBasicBlock, @@ -134,6 +141,7 @@ def __init__( dropout_rate, stride=2, num_estimators=self.num_estimators, + groups=groups, ) self.bn1 = nn.BatchNorm2d(nStages[3]) @@ -154,6 +162,7 @@ def _wide_layer( dropout_rate: float, stride: int, num_estimators: int, + groups: int, ): strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -161,11 +170,12 @@ def _wide_layer( for stride in strides: layers.append( block( - self.in_planes, - planes, - dropout_rate, - stride, - num_estimators, + in_planes=self.in_planes, + planes=planes, + dropout_rate=dropout_rate, + stride=stride, + num_estimators=num_estimators, + groups=groups, ) ) self.in_planes = planes @@ -191,6 +201,7 @@ def forward(self, x): def batched_wideresnet28x10( in_channels: int, num_estimators: int, + groups: int, num_classes: int, imagenet_structure: bool = True, ) -> _BatchedWide: @@ -214,5 +225,6 @@ def batched_wideresnet28x10( dropout_rate=0.3, num_classes=num_classes, num_estimators=num_estimators, + groups=groups, imagenet_structure=imagenet_structure, ) diff --git a/torch_uncertainty/models/wideresnet/packed.py 
b/torch_uncertainty/models/wideresnet/packed.py index 5d01b24d..4202685c 100644 --- a/torch_uncertainty/models/wideresnet/packed.py +++ b/torch_uncertainty/models/wideresnet/packed.py @@ -23,6 +23,7 @@ def __init__( alpha: float = 2, num_estimators: int = 4, gamma: int = 1, + groups: int = 1, ): super().__init__() self.bn1 = nn.BatchNorm2d(alpha * in_planes) @@ -33,6 +34,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, padding=1, bias=False, ) @@ -45,6 +47,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, stride=stride, padding=1, bias=False, @@ -59,6 +62,7 @@ def __init__( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, stride=stride, bias=True, ), @@ -81,6 +85,7 @@ def __init__( num_estimators: int = 4, alpha: int = 2, gamma: int = 1, + groups: int = 1, dropout_rate: float = 0, imagenet_structure: bool = True, ): @@ -104,7 +109,7 @@ def __init__( stride=2, padding=3, gamma=1, # No groups for the first layer - groups=1, + groups=groups, bias=True, first=True, ) @@ -118,6 +123,7 @@ def __init__( stride=1, padding=1, gamma=gamma, + groups=groups, bias=True, first=True, ) @@ -138,6 +144,7 @@ def __init__( alpha=alpha, num_estimators=self.num_estimators, gamma=gamma, + groups=groups, ) self.layer2 = self._wide_layer( WideBasicBlock, @@ -148,6 +155,7 @@ def __init__( alpha=alpha, num_estimators=self.num_estimators, gamma=gamma, + groups=groups, ) self.layer3 = self._wide_layer( WideBasicBlock, @@ -158,6 +166,7 @@ def __init__( alpha=alpha, num_estimators=self.num_estimators, gamma=gamma, + groups=groups, ) self.bn1 = nn.BatchNorm2d(nStages[3] * alpha, momentum=0.9) @@ -182,6 +191,7 @@ def _wide_layer( alpha: float, num_estimators: int, gamma: int, + groups: int, ): strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -196,6 +206,7 @@ def _wide_layer( alpha=alpha, num_estimators=num_estimators, gamma=gamma, + groups=groups, ) ) self.in_planes = planes @@ -224,6 +235,7 @@ def packed_wideresnet28x10( num_estimators: int, alpha: int, gamma: int, + groups: int, num_classes: int, imagenet_structure: bool = True, ) -> _PackedWide: @@ -251,5 +263,6 @@ def packed_wideresnet28x10( num_estimators=num_estimators, alpha=alpha, gamma=gamma, + groups=groups, imagenet_structure=imagenet_structure, ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 586314cd..031c31ac 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -408,14 +408,31 @@ def add_model_specific_args( - ``--logits``: sets :attr:`use_logits` to ``True``. - ``--mutual_information``: sets :attr:`use_mi` to ``True``. - ``--variation_ratio``: sets :attr:`use_variation_ratio` to ``True``. + - ``--num_estimators``: sets :attr:`num_estimators`. """ parent_parser = ClassificationSingle.add_model_specific_args( parent_parser ) + # FIXME: should be a str to choose among the available OOD criteria + # rather than a boolean, but it is not possible since + # ClassificationSingle and ClassificationEnsemble have different OOD + # criteria. 
parent_parser.add_argument( - "--mutual_information", dest="use_mi", action="store_true" + "--mutual_information", + dest="use_mi", + action="store_true", + default=False, ) parent_parser.add_argument( - "--variation_ratio", dest="use_variation_ratio", action="store_true" + "--variation_ratio", + dest="use_variation_ratio", + action="store_true", + default=False, + ) + parent_parser.add_argument( + "--num_estimators", + type=int, + default=None, + help="Number of estimators for ensemble", ) return parent_parser From 25823ae13b32c61bac823e7519eb2f0c56956fee Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 2 Jun 2023 09:04:39 +0200 Subject: [PATCH 19/23] Polish baselines and layer argument checks :hammer: --- .../baselines/classification/resnet.py | 46 +++-------------- .../baselines/classification/wideresnet.py | 49 +++---------------- .../layers/masksembles_layers.py | 6 +++ torch_uncertainty/layers/packed_layers.py | 10 ++++ 4 files changed, 29 insertions(+), 82 deletions(-) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 29c592e1..39056cee 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -88,68 +88,34 @@ def __new__( "in_channels": in_channels, "num_classes": num_classes, "imagenet_structure": imagenet_structure, + "groups": groups, } - # version specific parameters - if version == "vanilla": - # TODO: check parameters within a function - if groups < 1: - raise ValueError( - f"Number of groups must be at least 1, not {groups}" - ) - params.update({"groups": groups}) - elif version == "packed": - # TODO: check parameters within a function - if alpha <= 0: - raise ValueError( - f"Attribute `alpha` should be > 0, not {alpha}" - ) - if gamma < 1: - raise ValueError( - f"Attribute `gamma` should be >= 1, not {gamma}" - ) - if groups < 1: - raise ValueError( - f"Number of groups must be at least 1, not {groups}" - ) + + if version not in cls.versions.keys(): + raise ValueError(f"Unknown version: {version}") + + if version == "packed": params.update( { "num_estimators": num_estimators, "alpha": alpha, "gamma": gamma, - "groups": groups, "pretrained": pretrained, } ) elif version == "batched": - if groups < 1: - raise ValueError( - f"Number of groups must be at least 1, not {groups}" - ) params.update( { "num_estimators": num_estimators, - "groups": groups, } ) elif version == "masked": - # TODO: check parameters within a function - if scale < 1: - raise ValueError( - f"Attribute `scale` should be >= 1, not {scale}." - ) - if groups < 1: - raise ValueError( - f"Attribute `groups` should be >= 1, not {groups}." 
- ) params.update( { "num_estimators": num_estimators, "scale": scale, - "groups": groups, } ) - else: - raise ValueError(f"Unknown version: {version}") model = cls.versions[version][cls.archs.index(arch)](**params) kwargs.update(params) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 4f83b1f6..33cc9fc6 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -45,76 +45,41 @@ def __new__( use_logits: bool = False, use_mi: bool = False, use_variation_ratio: bool = False, - pretrained: bool = False, + # pretrained: bool = False, **kwargs, ) -> LightningModule: - # FIXME: should be a function to avoid repetition params = { "in_channels": in_channels, "num_classes": num_classes, "imagenet_structure": imagenet_structure, + "groups": groups, } + + if version not in cls.versions.keys(): + raise ValueError(f"Unknown version: {version}") + # version specific params - if version == "vanilla": - # TODO: check parameters within a function - if groups < 1: - raise ValueError( - f"Number of groups must be at least 1, not {groups}" - ) - params.update({"groups": groups}) - elif version == "packed": - # TODO: check parameters within a function - if alpha <= 0: - raise ValueError( - f"Attribute `alpha` should be > 0, not {alpha}" - ) - if gamma < 1: - raise ValueError( - f"Attribute `gamma` should be >= 1, not {gamma}" - ) - if groups < 1: - raise ValueError( - f"Number of groups must be at least 1, not {groups}" - ) + if version == "packed": params.update( { "num_estimators": num_estimators, "alpha": alpha, "gamma": gamma, - "groups": groups, - # "pretrained": pretrained, } ) elif version == "batched": - if groups < 1: - raise ValueError( - f"Number of groups must be at least 1, not {groups}" - ) params.update( { "num_estimators": num_estimators, - "groups": groups, } ) elif version == "masked": - # TODO: check parameters within a function - if scale < 1: - raise ValueError( - f"Attribute `scale` should be >= 1, not {scale}." - ) - if groups < 1: - raise ValueError( - f"Attribute `groups` should be >= 1, not {groups}." 
- ) params.update( { "num_estimators": num_estimators, "scale": scale, - "groups": groups, } ) - else: - raise ValueError(f"Unknown version: {version}") model = cls.versions[version][0](**params) kwargs.update(params) diff --git a/torch_uncertainty/layers/masksembles_layers.py b/torch_uncertainty/layers/masksembles_layers.py index a087b9bc..7ebff398 100644 --- a/torch_uncertainty/layers/masksembles_layers.py +++ b/torch_uncertainty/layers/masksembles_layers.py @@ -203,6 +203,9 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() + if scale < 1: + raise ValueError(f"Attribute `scale` should be >= 1, not {scale}.") + self.mask = Mask1D( in_features, num_masks=num_estimators, scale=scale, **factory_kwargs ) @@ -264,6 +267,9 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() + if scale < 1: + raise ValueError(f"Attribute `scale` should be >= 1, not {scale}.") + self.mask = Mask2D( in_channels, num_masks=num_estimators, scale=scale, **factory_kwargs ) diff --git a/torch_uncertainty/layers/packed_layers.py b/torch_uncertainty/layers/packed_layers.py index acde874c..2c635288 100644 --- a/torch_uncertainty/layers/packed_layers.py +++ b/torch_uncertainty/layers/packed_layers.py @@ -60,6 +60,11 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() + if alpha <= 0: + raise ValueError(f"Attribute `alpha` should be > 0, not {alpha}") + if gamma <= 0: + raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}") + self.num_estimators = num_estimators self.rearrange = rearrange @@ -164,6 +169,11 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() + if alpha <= 0: + raise ValueError(f"Attribute `alpha` should be > 0, not {alpha}") + if gamma <= 0: + raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}") + self.num_estimators = num_estimators # Define the number of channels of the underlying convolution From bbee3bd9a40d987590a6e18642a4cfcb69232985 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 2 Jun 2023 09:44:34 +0200 Subject: [PATCH 20/23] Add docstrings to baselines :bulb: --- .../baselines/classification/resnet.py | 74 +++++++++++++++++-- .../baselines/classification/wideresnet.py | 55 ++++++++++++++ 2 files changed, 124 insertions(+), 5 deletions(-) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 39056cee..3c22cdd7 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -35,6 +35,70 @@ # fmt: on class ResNet: + r"""ResNet backbone baseline for classification providing support for + various versions and architectures. + + Args: + num_classes (int): Number of classes to predict. + in_channels (int): Number of input channels. + loss (nn.Module): Training loss. + optimization_procedure (Any): Optimization procedure, corresponds to + what is expected by the `LightningModule.configure_optimizers() + `_ + method. + version (str): + Determines which ResNet version to use: + + - ``"vanilla"``: original ResNet + - ``"packed"``: Packed-Ensembles ResNet + - ``"batched"``: BatchEnsemble ResNet + - ``"masked"``: Masksemble ResNet + + arch (int): + Determines which ResNet architecture to use: + + - ``18``: ResNet-18 + - ``34``: ResNet-34 + - ``50``: ResNet-50 + - ``101``: ResNet-101 + - ``152``: ResNet-152 + + imagenet_structure (bool, optional): Whether to use the ImageNet + structure. Defaults to ``True``.
+ num_estimators (int, optional): Number of estimators in the ensemble. + Only used if :attr:`version` is either ``"packed"``, ``"batched"`` + or ``"masked"``. Defaults to ``None``. + groups (int, optional): Number of groups in convolutions. Defaults to + ``1``. + scale (float, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"masked"``. Defaults + to ``None``. + alpha (float, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"packed"``. Defaults + to ``None``. + gamma (int, optional): Number of groups within each estimator. Only + used if :attr:`version` is ``"packed"`` and scales with + :attr:`groups`. Defaults to ``1``. + use_entropy (bool, optional): Indicates whether to use the entropy + values as the OOD criterion or not. Defaults to ``False``. + use_logits (bool, optional): Indicates whether to use the logits as the + OOD criterion or not. Defaults to ``False``. + use_mi (bool, optional): Indicates whether to use the mutual + information as the OOD criterion or not. Defaults to ``False``. + use_variation_ratio (bool, optional): Indicates whether to use the + variation ratio as the OOD criterion or not. Defaults to ``False``. + pretrained (bool, optional): Indicates whether to use the pretrained + weights or not. Only used if :attr:`version` is ``"packed"``. + Defaults to ``False``. + + Raises: + ValueError: If :attr:`version` is not either ``"vanilla"``, + ``"packed"``, ``"batched"`` or ``"masked"``. + + Returns: + LightningModule: ResNet baseline ready for training and evaluation. + """ + single = ["vanilla"] ensemble = ["packed", "batched", "masked"] versions = { @@ -73,10 +137,10 @@ def __new__( arch: int, imagenet_structure: bool = True, num_estimators: Optional[int] = None, - groups: Optional[int] = None, + groups: Optional[int] = 1, scale: Optional[float] = None, - alpha: Optional[int] = None, - gamma: Optional[int] = None, + alpha: Optional[float] = None, + gamma: Optional[int] = 1, use_entropy: bool = False, use_logits: bool = False, use_mi: bool = False, @@ -172,14 +236,14 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: ) parser.add_argument( "--alpha", - type=int, + type=float, default=None, help="Alpha for packed resnet", ) parser.add_argument( "--gamma", type=int, - default=None, + default=1, help="Gamma for packed resnet", ) parser.add_argument( diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index 33cc9fc6..b5a5fdf0 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -19,6 +19,61 @@ # fmt: on class WideResNet: + r"""Wide-ResNet28x10 backbone baseline for classification providing support + for various versions. + + Args: + num_classes (int): Number of classes to predict. + in_channels (int): Number of input channels. + loss (nn.Module): Training loss. + optimization_procedure (Any): Optimization procedure, corresponds to + what is expected by the `LightningModule.configure_optimizers() + `_ + method. + version (str): + Determines which Wide-ResNet version to use: + + - ``"vanilla"``: original Wide-ResNet + - ``"packed"``: Packed-Ensembles Wide-ResNet + - ``"batched"``: BatchEnsemble Wide-ResNet + - ``"masked"``: Masksemble Wide-ResNet + + imagenet_structure (bool, optional): Whether to use the ImageNet + structure. Defaults to ``True``.
+ num_estimators (int, optional): Number of estimators in the ensemble. + Only used if :attr:`version` is either ``"packed"``, ``"batched"`` + or ``"masked"``. Defaults to ``None``. + groups (int, optional): Number of groups in convolutions. Defaults to + ``1``. + scale (float, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"masked"``. Defaults + to ``None``. + alpha (float, optional): Expansion factor affecting the width of the + estimators. Only used if :attr:`version` is ``"packed"``. Defaults + to ``None``. + gamma (int, optional): Number of groups within each estimator. Only + used if :attr:`version` is ``"packed"`` and scales with + :attr:`groups`. Defaults to ``1``. + use_entropy (bool, optional): Indicates whether to use the entropy + values as the OOD criterion or not. Defaults to ``False``. + use_logits (bool, optional): Indicates whether to use the logits as the + OOD criterion or not. Defaults to ``False``. + use_mi (bool, optional): Indicates whether to use the mutual + information as the OOD criterion or not. Defaults to ``False``. + use_variation_ratio (bool, optional): Indicates whether to use the + variation ratio as the OOD criterion or not. Defaults to ``False``. + pretrained (bool, optional): Indicates whether to use the pretrained + weights or not. Only used if :attr:`version` is ``"packed"``. + Defaults to ``False``. + + Raises: + ValueError: If :attr:`version` is not either ``"vanilla"``, + ``"packed"``, ``"batched"`` or ``"masked"``. + + Returns: + LightningModule: Wide-ResNet baseline ready for training and + evaluation. + """ single = ["vanilla"] ensemble = ["packed", "batched", "masked"] versions = { From c0c44815057e54bb84616582317a94784646cbba Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 2 Jun 2023 10:12:43 +0200 Subject: [PATCH 21/23] Simplify parser arguments for baselines :hammer: --- .../baselines/classification/resnet.py | 43 ++++---------- .../baselines/classification/wideresnet.py | 36 ++++-------- .../baselines/utils/parser_addons.py | 56 +++++++++++++++++++ 3 files changed, 76 insertions(+), 59 deletions(-) create mode 100644 torch_uncertainty/baselines/utils/parser_addons.py diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 3c22cdd7..c90fdbe5 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from pytorch_lightning import LightningModule -from torch_uncertainty.models.resnet import ( +from ...models.resnet import ( batched_resnet18, @@ -27,10 +27,15 @@ resnet101, resnet152, ) -from torch_uncertainty.routines.classification import ( +from ...routines.classification import ( ClassificationEnsemble, ClassificationSingle, ) +from ..utils.parser_addons import ( + add_masked_specific_args, + add_packed_specific_args, + add_resnet_specific_args, +) # fmt: on @@ -208,6 +213,9 @@ def __new__( @classmethod def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: parser = ClassificationEnsemble.add_model_specific_args(parser) + parser = add_resnet_specific_args(parser) + parser = add_packed_specific_args(parser) + parser = add_masked_specific_args(parser) parser.add_argument( "--version", type=str, @@ -215,37 +223,6 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: default="vanilla", help=f"Variation of ResNet. 
Choose among: {cls.versions.keys()}", ) - parser.add_argument( - "--arch", - type=int, - choices=cls.archs, - default=18, - help=f"Architecture of ResNet. Choose among: {cls.archs}", - ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups for vanilla or masked resnet", - ) - parser.add_argument( - "--scale", - type=float, - default=None, - help="Scale for masked resnet", - ) - parser.add_argument( - "--alpha", - type=float, - default=None, - help="Alpha for packed resnet", - ) - parser.add_argument( - "--gamma", - type=int, - default=1, - help="Gamma for packed resnet", - ) parser.add_argument( "--pretrained", dest="pretrained", diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index b5a5fdf0..64c82d03 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -5,16 +5,21 @@ import torch.nn as nn from pytorch_lightning import LightningModule -from torch_uncertainty.models.wideresnet import ( +from ...models.wideresnet import ( batched_wideresnet28x10, masked_wideresnet28x10, packed_wideresnet28x10, wideresnet28x10, ) -from torch_uncertainty.routines.classification import ( +from ...routines.classification import ( ClassificationEnsemble, ClassificationSingle, ) +from ..utils.parser_addons import ( + add_masked_specific_args, + add_packed_specific_args, + add_wideresnet_specific_args, +) # fmt: on @@ -163,6 +168,9 @@ def __new__( @classmethod def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: parser = ClassificationEnsemble.add_model_specific_args(parser) + parser = add_wideresnet_specific_args(parser) + parser = add_packed_specific_args(parser) + parser = add_masked_specific_args(parser) parser.add_argument( "--version", type=str, @@ -171,30 +179,6 @@ def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: help="Variation of WideResNet. " + f"Choose among: {cls.versions.keys()}", ) - parser.add_argument( - "--groups", - type=int, - default=1, - help="Number of groups for vanilla or masked wideresnet", - ) - parser.add_argument( - "--scale", - type=float, - default=None, - help="Scale for masked wideresnet", - ) - parser.add_argument( - "--alpha", - type=int, - default=None, - help="Alpha for packed wideresnet", - ) - parser.add_argument( - "--gamma", - type=int, - default=None, - help="Gamma for packed wideresnet", - ) parser.add_argument( "--pretrained", dest="pretrained", diff --git a/torch_uncertainty/baselines/utils/parser_addons.py b/torch_uncertainty/baselines/utils/parser_addons.py new file mode 100644 index 00000000..e157812b --- /dev/null +++ b/torch_uncertainty/baselines/utils/parser_addons.py @@ -0,0 +1,56 @@ +# fmt: off +from argparse import ArgumentParser + + +# fmt: on +def add_resnet_specific_args(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--arch", + type=int, + choices=[18, 34, 50, 101, 152], + default=18, + help=f"Architecture of ResNet. 
Choose among: {[18, 34, 50, 101, 152]}", + ) + parser.add_argument( + "--groups", + type=int, + default=1, + help="Number of groups", + ) + return parser + + +def add_wideresnet_specific_args(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--groups", + type=int, + default=1, + help="Number of groups", + ) + return parser + + +def add_packed_specific_args(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--alpha", + type=float, + default=None, + help="Alpha for Packed-Ensembles", + ) + parser.add_argument( + "--gamma", + type=int, + default=1, + help="Gamma for Packed-Ensembles", + ) + return parser + + +def add_masked_specific_args(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--scale", + type=float, + default=None, + help="Scale for Masksembles", + ) + return parser From 945cbe55674986919c74c2616a42fe9686df2ea3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 2 Jun 2023 11:22:25 +0200 Subject: [PATCH 22/23] Solve review comments :ok_hand: --- experiments/readme.md | 2 +- torch_uncertainty/layers/packed_layers.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/experiments/readme.md b/experiments/readme.md index 05bf5fc5..f303a316 100644 --- a/experiments/readme.md +++ b/experiments/readme.md @@ -4,4 +4,4 @@ Torch-Uncertainty proposes various benchmarks to evaluate the uncertainty estima ## Classification -*Work in progress* \ No newline at end of file +*Work in progress* diff --git a/torch_uncertainty/layers/packed_layers.py b/torch_uncertainty/layers/packed_layers.py index 2c635288..05da674f 100644 --- a/torch_uncertainty/layers/packed_layers.py +++ b/torch_uncertainty/layers/packed_layers.py @@ -62,6 +62,11 @@ def __init__( if alpha <= 0: raise ValueError(f"Attribute `alpha` should be > 0, not {alpha}") + + if not isinstance(gamma, int): + raise ValueError( + f"Attribute `gamma` should be an int, not {type(gamma)}" + ) if gamma <= 0: raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}") From e93ce38b4d6df0bbfecc3e449a18c2cd61461309 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 2 Jun 2023 11:24:24 +0200 Subject: [PATCH 23/23] Add forgotten consistency check :ok_hand: --- torch_uncertainty/layers/packed_layers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch_uncertainty/layers/packed_layers.py b/torch_uncertainty/layers/packed_layers.py index 05da674f..fe0a031d 100644 --- a/torch_uncertainty/layers/packed_layers.py +++ b/torch_uncertainty/layers/packed_layers.py @@ -176,6 +176,11 @@ def __init__( if alpha <= 0: raise ValueError(f"Attribute `alpha` should be > 0, not {alpha}") + + if not isinstance(gamma, int): + raise ValueError( + f"Attribute `gamma` should be an int, not {type(gamma)}" + ) if gamma <= 0: raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}")
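Taken together, the series leaves one entry point per backbone whose CLI surface is assembled from the routine flags plus the parser add-ons above. A minimal end-to-end sketch of the resulting API; the `ResNet` import path and the `get_procedure` call follow earlier patches in this series, the numeric values are arbitrary, and the datamodule/trainer wiring is omitted:

from argparse import ArgumentParser

import torch.nn as nn

from torch_uncertainty.baselines import ResNet
from torch_uncertainty.optimization_procedures import get_procedure

# Routine flags, the ResNet/packed/masked add-ons, and --version /
# --pretrained are all registered through the single classmethod.
parser = ArgumentParser()
parser = ResNet.add_model_specific_args(parser)
args = parser.parse_args(
    ["--version", "packed", "--arch", "18",
     "--num_estimators", "4", "--alpha", "2", "--gamma", "1"]
)

# ResNet.__new__ validates the packed parameters and returns a
# ClassificationEnsemble LightningModule ready for a pl.Trainer.
model = ResNet(
    num_classes=10,
    in_channels=3,
    loss=nn.CrossEntropyLoss,
    optimization_procedure=get_procedure("resnet18", "cifar10", "packed"),
    **vars(args),
)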