diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml
index 084adb66..f9fc48f3 100644
--- a/.github/workflows/build-docs.yml
+++ b/.github/workflows/build-docs.yml
@@ -31,7 +31,7 @@ jobs:
           echo "PYTHON_VERSION=$(python -c "import platform; print(platform.python_version())")"
       - name: Cache folder for TorchUncertainty
-        uses: actions/cache@v3
+        uses: actions/cache@v4
        id: cache-folder
        with:
          path: |
@@ -41,7 +41,7 @@ jobs:
      - name: Install dependencies
        run: |
          python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-          python3 -m pip install .[image,dev,docs]
+          python3 -m pip install .[all]

      - name: Sphinx build
        if: github.event.pull_request.draft == false
diff --git a/.gitignore b/.gitignore
index e8f10e61..659ceed3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 data/
 logs/
 lightning_logs/
+auto_tutorials_source/*.png
 docs/*/generated/
 docs/*/auto_tutorials/
 *.pth
diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py
index 939e83c1..c32c991d 100644
--- a/auto_tutorials_source/tutorial_bayesian.py
+++ b/auto_tutorials_source/tutorial_bayesian.py
@@ -27,7 +27,7 @@
 To train a BNN using TorchUncertainty, we have to load the following modules:

-- the Trainer from Lightning
+- our TUTrainer
 - the model: bayesian_lenet, which lies in torch_uncertainty.models
 - the classification training routine from torch_uncertainty.routines
 - the Bayesian objective: the ELBOLoss, which lies in the torch_uncertainty.losses file
@@ -39,9 +39,9 @@
 # %%
 from pathlib import Path

-from lightning.pytorch import Trainer
 from torch import nn, optim

+from torch_uncertainty import TUTrainer
 from torch_uncertainty.datamodules import MNISTDataModule
 from torch_uncertainty.losses import ELBOLoss
 from torch_uncertainty.models.lenet import bayesian_lenet
@@ -65,12 +65,12 @@ def optim_lenet(model: nn.Module):
 # 3. Creating the necessary variables
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
-# In the following, we define the Lightning trainer, the root of the datasets and the logs.
+# In the following, we instantiate our trainer and define the root of the datasets and the logs.
 # We also create the datamodule that handles the MNIST dataset, dataloaders and transforms.
 # Please note that the datamodules can also handle OOD detection by setting the eval_ood
 # parameter to True. Finally, we create the model using the blueprint from torch_uncertainty.models.

-trainer = Trainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1)
+trainer = TUTrainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1)

 # datamodule
 root = Path("data")
@@ -111,7 +111,7 @@ def optim_lenet(model: nn.Module):
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # Now that we have prepared all of this, we just have to gather everything in
-# the main function and to train the model using the Lightning Trainer.
+# the main function and to train the model using our wrapper of the Lightning Trainer.
 # Specifically, it needs the routine, which includes the model as well as the
 # training/eval logic, and the datamodule.
 # The dataset will be downloaded automatically in the root/data folder, and the
diff --git a/auto_tutorials_source/tutorial_corruption.py b/auto_tutorials_source/tutorial_corruption.py
index 4eecd6ff..734e957c 100644
--- a/auto_tutorials_source/tutorial_corruption.py
+++ b/auto_tutorials_source/tutorial_corruption.py
@@ -12,23 +12,35 @@
 torchvision and matplotlib.
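Before moving on to the corruption tutorial, the renamed pieces of the Bayesian tutorial assemble roughly as follows. This is a minimal sketch rather than part of the patch: the ClassificationRoutine wiring and the ELBOLoss arguments (inner_loss, kl_weight, num_samples) are assumptions based on the tutorial this hunk edits.

    from pathlib import Path

    from torch import nn, optim

    from torch_uncertainty import TUTrainer
    from torch_uncertainty.datamodules import MNISTDataModule
    from torch_uncertainty.losses import ELBOLoss
    from torch_uncertainty.models.lenet import bayesian_lenet
    from torch_uncertainty.routines import ClassificationRoutine

    trainer = TUTrainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1)
    datamodule = MNISTDataModule(root=Path("data"), batch_size=128)

    model = bayesian_lenet(datamodule.num_channels, datamodule.num_classes)
    loss = ELBOLoss(
        model=model,
        inner_loss=nn.CrossEntropyLoss(),
        kl_weight=1 / 50000,  # assumption: scale the KL term by 1/num_training_samples
        num_samples=3,  # Monte Carlo samples drawn from the weight posterior
    )
    routine = ClassificationRoutine(
        model=model,
        num_classes=datamodule.num_classes,
        loss=loss,
        optim_recipe=optim.Adam(model.parameters(), lr=1e-3),
    )
    trainer.fit(model=routine, datamodule=datamodule)
    trainer.test(model=routine, datamodule=datamodule)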
""" # %% -from torchvision.datasets import CIFAR10 -from torchvision.transforms import Compose, ToTensor, Resize +from torchvision.transforms import Compose, ToTensor, Resize, CenterCrop import matplotlib.pyplot as plt +from PIL import Image +from urllib import request -ds = CIFAR10("./data", train=False, download=True) +urls = [ + "https://upload.wikimedia.org/wikipedia/commons/d/d9/Carduelis_tristis_-Michigan%2C_USA_-male-8.jpg", + "https://upload.wikimedia.org/wikipedia/commons/5/5d/Border_Collie_Blanca_y_Negra_Hembra_%28Belen%2C_Border_Collie_Los_Baganes%29.png", + "https://upload.wikimedia.org/wikipedia/commons/f/f8/Birmakatze_Seal-Point.jpg", + "https://upload.wikimedia.org/wikipedia/commons/a/a9/Garranos_fight.jpg", + "https://upload.wikimedia.org/wikipedia/commons/8/8b/Cottontail_Rabbit.jpg", +] + +def download_img(url, i): + request.urlretrieve(url, f"tmp_{i}.png") + return Image.open(f"tmp_{i}.png").convert('RGB') + +images_ds = [download_img(url, i) for i, url in enumerate(urls)] def get_images(main_corruption, index: int = 0): """Create an image showing the 6 levels of corruption of a given transform.""" images = [] for severity in range(6): - ds_transforms = Compose( - [ToTensor(), main_corruption(severity), Resize(256, antialias=True)] + transforms = Compose( + [Resize(256, antialias=True), CenterCrop(256), ToTensor(), main_corruption(severity), CenterCrop(224)] ) - ds = CIFAR10("./data", train=False, download=False, transform=ds_transforms) - images.append(ds[index][0].permute(1, 2, 0).numpy()) + images.append(transforms(images_ds[index]).permute(1, 2, 0).numpy()) return images @@ -65,7 +77,6 @@ def show_images(transforms): GaussianNoise, ShotNoise, ImpulseNoise, - SpeckleNoise, ) show_images( @@ -73,7 +84,6 @@ def show_images(transforms): GaussianNoise, ShotNoise, ImpulseNoise, - SpeckleNoise, ] ) @@ -81,33 +91,71 @@ def show_images(transforms): # 2. Blur Corruptions # ~~~~~~~~~~~~~~~~~~~~ from torch_uncertainty.transforms.corruption import ( - GaussianBlur, + MotionBlur, GlassBlur, DefocusBlur, + ZoomBlur, ) show_images( [ - GaussianBlur, GlassBlur, + MotionBlur, DefocusBlur, + ZoomBlur, ] ) # %% -# 3. Other Corruptions -# ~~~~~~~~~~~~~~~~~~~~ +# 3. Weather Corruptions +# ~~~~~~~~~~~~~~~~~~~~~~ from torch_uncertainty.transforms.corruption import ( - JPEGCompression, - Pixelate, Frost, + Snow, + Fog, ) show_images( [ + Fog, + Frost, + Snow, + ] +) + +# %% +# 4. Other Corruptions + +from torch_uncertainty.transforms.corruption import ( + Brightness, Contrast, Elastic, JPEGCompression, Pixelate) + +show_images( + [ + Brightness, + Contrast, JPEGCompression, Pixelate, - Frost, + Elastic, + ] +) + +# %% +# 5. Unused Corruptions +# ~~~~~~~~~~~~~~~~~~~~~ + +# The following corruptions are not used in the paper Benchmarking Neural Network Robustness to Common Corruptions and Perturbations. 
+
+from torch_uncertainty.transforms.corruption import (
+    GaussianBlur,
+    SpeckleNoise,
+    Saturation,
+)
+
+show_images(
+    [
+        GaussianBlur,
+        SpeckleNoise,
+        Saturation,
+    ]
+)
diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py
index a30b49d5..24ee20a3 100644
--- a/auto_tutorials_source/tutorial_der_cubic.py
+++ b/auto_tutorials_source/tutorial_der_cubic.py
@@ -21,7 +21,7 @@
 To train an MLP with the DER loss function using TorchUncertainty, we have to load the following modules:

-- the Trainer from Lightning
+- our TUTrainer
 - the model: mlp from torch_uncertainty.models.mlp
 - the regression training routine from torch_uncertainty.routines
 - the evidential objective: the DERLoss from torch_uncertainty.losses. This loss contains the classic NLL loss and a regularization term.
@@ -31,10 +31,10 @@
 """
 # %%
 import torch
-from lightning.pytorch import Trainer
 from lightning import LightningDataModule
 from torch import nn, optim

+from torch_uncertainty import TUTrainer
 from torch_uncertainty.models.mlp import mlp
 from torch_uncertainty.datasets.regression.toy import Cubic
 from torch_uncertainty.losses import DERLoss
@@ -67,7 +67,7 @@ def optim_regression(
 # Please note that this MLP finishes with a NormalInverseGammaLayer that interprets the outputs of the model
 # as the parameters of a Normal Inverse Gamma distribution.

-trainer = Trainer(accelerator="cpu", max_epochs=50) #, enable_progress_bar=False)
+trainer = TUTrainer(accelerator="cpu", max_epochs=50)

 # dataset
 train_ds = Cubic(num_samples=1000)
diff --git a/auto_tutorials_source/tutorial_evidential_classification.py b/auto_tutorials_source/tutorial_evidential_classification.py
index cd124f5d..babf2a73 100644
--- a/auto_tutorials_source/tutorial_evidential_classification.py
+++ b/auto_tutorials_source/tutorial_evidential_classification.py
@@ -16,7 +16,7 @@
 To train a LeNet with the DEC loss function using TorchUncertainty, we have to load the following utilities from TorchUncertainty:

-- the Trainer from Lightning
+- our wrapper of the Lightning Trainer
 - the model: LeNet, which lies in torch_uncertainty.models
 - the classification training routine in the torch_uncertainty.routines
 - the evidential objective: the DECLoss from torch_uncertainty.losses
@@ -28,9 +28,9 @@
 from pathlib import Path

 import torch
-from lightning.pytorch import Trainer
 from torch import nn, optim

+from torch_uncertainty import TUTrainer
 from torch_uncertainty.datamodules import MNISTDataModule
 from torch_uncertainty.losses import DECLoss
 from torch_uncertainty.models.lenet import lenet
@@ -53,10 +53,9 @@ def optim_lenet(model: nn.Module) -> dict:
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # In the following, we need to define the root of the logs, and to
-# fake-parse the arguments needed for using the PyTorch Lightning Trainer. We
-# also use the same MNIST classification example as that used in the
+# use the same MNIST classification example as that used in the
 # original DEC paper. We only train for 3 epochs for the sake of time.
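As an aside on the der_cubic changes above, here is a sketch of the DERLoss contract on a toy batch. The NormalInverseGamma distribution class and its parameter names are assumptions about torch_uncertainty's utilities; only the DERLoss import itself appears in the patch.

    import torch

    from torch_uncertainty.losses import DERLoss
    from torch_uncertainty.utils.distributions import NormalInverseGamma

    # Stand-in for the output of an MLP ending with a NormalInverseGammaLayer.
    loc = torch.zeros(8, 1)
    lmbda = torch.ones(8, 1)  # evidence-related parameters must be positive
    alpha = 2 * torch.ones(8, 1)
    beta = torch.ones(8, 1)
    dist = NormalInverseGamma(loc, lmbda, alpha, beta)

    targets = torch.randn(8, 1)
    loss = DERLoss(reg_weight=1e-2)(dist, targets)  # NLL + weighted evidential penalty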
-trainer = Trainer(accelerator="cpu", max_epochs=3, enable_progress_bar=False)
+trainer = TUTrainer(accelerator="cpu", max_epochs=3, enable_progress_bar=False)

 # datamodule
 root = Path() / "data"
diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py
index 55de3735..fd60e3f1 100644
--- a/auto_tutorials_source/tutorial_from_de_to_pe.py
+++ b/auto_tutorials_source/tutorial_from_de_to_pe.py
@@ -149,7 +149,7 @@ def optim_recipe(model, lr_mult: float = 1.0):

 from torch_uncertainty.routines import ClassificationRoutine
-from torch_uncertainty.utils import TUTrainer
+from torch_uncertainty import TUTrainer

 # Create the trainer that will handle the training
 trainer = TUTrainer(accelerator="cpu", max_epochs=max_epochs)
@@ -242,7 +242,7 @@ def optim_recipe(model, lr_mult: float = 1.0):
 # We have put the pre-trained models on Hugging Face that you can download with the utility function
 # "hf_hub_download" imported just below. These models are trained for 75 epochs and are therefore not
 # comparable to all the other models trained in this notebook. The pretrained models can be seen
-# on `HuggingFace `_ and TorchUncertainty's are `here `_.
+# on `HuggingFace `_ and TorchUncertainty's are `there `_.

 from torch_uncertainty.utils.hub import hf_hub_download
diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py
index 886ed9cf..bb726902 100644
--- a/auto_tutorials_source/tutorial_mc_batch_norm.py
+++ b/auto_tutorials_source/tutorial_mc_batch_norm.py
@@ -13,7 +13,7 @@
 First, we have to load the following utilities from TorchUncertainty:

-- the Trainer from Lightning
+- the TUTrainer from our framework
 - the datamodule handling dataloaders: MNISTDataModule from torch_uncertainty.datamodules
 - the model: LeNet, which lies in torch_uncertainty.models
 - the MC Batch Normalization wrapper: mc_batch_norm, which lies in torch_uncertainty.post_processing
@@ -25,9 +25,9 @@
 # %%
 from pathlib import Path

-from lightning import Trainer
 from torch import nn

+from torch_uncertainty import TUTrainer
 from torch_uncertainty.datamodules import MNISTDataModule
 from torch_uncertainty.models.lenet import lenet
 from torch_uncertainty.optim_recipes import optim_cifar10_resnet18
@@ -41,7 +41,7 @@
 # logs. We also create the datamodule that handles the MNIST dataset
 # dataloaders and transforms.
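Since the tutorial above builds from deep ensembles toward packed ensembles, and the tests later in this PR exercise the new PackedLinear implementation variants, a small sketch may help. The torch_uncertainty.layers import path and the alpha/num_estimators semantics are assumptions consistent with those tests.

    import torch

    from torch_uncertainty.layers import PackedLinear

    # Four estimators packed into one layer; alpha scales the shared width, and
    # implementation selects the kernel ("full", "sparse", or "einsum").
    layer = PackedLinear(16, 4, alpha=2, num_estimators=4, implementation="einsum")
    out = layer(torch.rand(2, 16))  # same (batch, features) interface as nn.Linear
    print(out.shape)  # torch.Size([2, 4])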
-trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False)
+trainer = TUTrainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False)

 # datamodule
 root = Path("data")
diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py
index b8f01fb0..d17df7c2 100644
--- a/auto_tutorials_source/tutorial_mc_dropout.py
+++ b/auto_tutorials_source/tutorial_mc_dropout.py
@@ -31,7 +31,7 @@
 # %%
 from pathlib import Path

-from torch_uncertainty.utils import TUTrainer
+from torch_uncertainty import TUTrainer

 from torch import nn
 from torch_uncertainty.datamodules import MNISTDataModule
diff --git a/docs/source/api.rst b/docs/source/api.rst
index cd778cb2..4f3e0177 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -320,6 +320,11 @@ Losses
     ELBOLoss
     BetaNLL
     DECLoss
+    DERLoss
+    FocalLoss
+    ConflictualLoss
+    ConfidencePenaltyLoss
+    KLDiv

 Post-Processing Methods
 -----------------------
diff --git a/docs/source/cli_guide.rst b/docs/source/cli_guide.rst
index 0b888ea9..07904b98 100644
--- a/docs/source/cli_guide.rst
+++ b/docs/source/cli_guide.rst
@@ -22,7 +22,7 @@ Let's see how to implement the CLI, by checking out the ``experiments/classifica

     from torch_uncertainty.baselines.classification import ResNetBaseline
     from torch_uncertainty.datamodules import CIFAR10DataModule
-    from torch_uncertainty.utils import TULightningCLI
+    from torch_uncertainty import TULightningCLI

     class ResNetCLI(TULightningCLI):
diff --git a/docs/source/conf.py b/docs/source/conf.py
index c5d3676f..2b6426ea 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -15,7 +15,7 @@
     f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
 )
 author = "Adrien Lafage and Olivier Laurent"
-release = "0.2.2.post2"
+release = "0.3.0"

 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index b81eb766..42d4dfb7 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -86,16 +86,16 @@ CIFAR10 datamodule.

 .. code:: python

     from torch_uncertainty.datamodules import CIFAR10DataModule
-    from lightning.pytorch import Trainer
+    from torch_uncertainty import TUTrainer

     dm = CIFAR10DataModule(root="data", batch_size=32)

-    trainer = Trainer(gpus=1, max_epochs=100)
+    trainer = TUTrainer(accelerator="gpu", devices=1, max_epochs=100)
     trainer.fit(routine, dm)
     trainer.test(routine, dm)

 Here it is, you have trained your first model with TorchUncertainty! As a result, you
 will get access to various metrics measuring the ability of your model to handle uncertainty.
 You can get other examples of training with lightning Trainers
-looking at the `Tutorials `_.
+by looking at the `Tutorials `_.

 More metrics
 ^^^^^^^^^^^^
diff --git a/docs/source/references.rst b/docs/source/references.rst
index 5e6c7425..eb72cbee 100644
--- a/docs/source/references.rst
+++ b/docs/source/references.rst
@@ -246,6 +246,16 @@ For Laplace Approximation, consider citing:

 Losses
 ------

+Focal Loss
+^^^^^^^^^^
+
+For the focal loss, consider citing:
+
+**Focal Loss for Dense Object Detection**
+
+* Authors: *Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, and Piotr Dollár*
+* Paper: `TPAMI 2020 `__.
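The FocalLoss referenced above behaves as a drop-in cross-entropy variant; this usage sketch follows the signature exercised by the tests added later in this PR (a non-negative gamma and an optional reduction):

    import torch

    from torch_uncertainty.losses import FocalLoss

    logits = torch.randn(4, 3)  # (batch, num_classes)
    targets = torch.randint(0, 3, (4,))  # integer class labels
    loss = FocalLoss(gamma=2.0)(logits, targets)  # gamma > 0 down-weights easy examples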
+ Conflictual Loss ^^^^^^^^^^^^^^^^ diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index 6deddd4c..de03521d 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -1,13 +1,14 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.utils import TULightningCLI class ResNetCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) parser.add_optimizer_args(torch.optim.SGD) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) diff --git a/experiments/classification/cifar10/vgg.py b/experiments/classification/cifar10/vgg.py index a4614e3a..0f40d498 100644 --- a/experiments/classification/cifar10/vgg.py +++ b/experiments/classification/cifar10/vgg.py @@ -1,13 +1,14 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.classification import VGGBaseline from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.utils import TULightningCLI class ResNetCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) parser.add_optimizer_args(torch.optim.Adam) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) diff --git a/experiments/classification/cifar10/wideresnet.py b/experiments/classification/cifar10/wideresnet.py index 03870002..bf11e2d2 100644 --- a/experiments/classification/cifar10/wideresnet.py +++ b/experiments/classification/cifar10/wideresnet.py @@ -1,13 +1,14 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.classification import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.utils import TULightningCLI class ResNetCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) parser.add_optimizer_args(torch.optim.SGD) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index 15fb4eae..e62de94f 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -30,7 +30,7 @@ model: num_classes: 100 in_channels: 3 loss: CrossEntropyLoss - version: standard + version: std arch: 18 style: cifar data: diff --git a/experiments/classification/cifar100/resnet.py b/experiments/classification/cifar100/resnet.py index 0c3a0068..8c8b8a00 100644 --- a/experiments/classification/cifar100/resnet.py +++ b/experiments/classification/cifar100/resnet.py @@ -1,13 +1,14 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.classification import ResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.utils import TULightningCLI class 
ResNetCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) parser.add_optimizer_args(torch.optim.SGD) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) diff --git a/experiments/classification/cifar100/vgg.py b/experiments/classification/cifar100/vgg.py index 1936f809..af07c997 100644 --- a/experiments/classification/cifar100/vgg.py +++ b/experiments/classification/cifar100/vgg.py @@ -1,13 +1,14 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.classification import VGGBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.utils import TULightningCLI class ResNetCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) parser.add_optimizer_args(torch.optim.Adam) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) diff --git a/experiments/classification/cifar100/wideresnet.py b/experiments/classification/cifar100/wideresnet.py index 49b9a227..f29ad2ff 100644 --- a/experiments/classification/cifar100/wideresnet.py +++ b/experiments/classification/cifar100/wideresnet.py @@ -1,13 +1,14 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.classification import WideResNetBaseline from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.utils import TULightningCLI class ResNetCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) parser.add_optimizer_args(torch.optim.SGD) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR) diff --git a/experiments/classification/mnist/lenet.py b/experiments/classification/mnist/lenet.py index c93610f4..24e5b203 100644 --- a/experiments/classification/mnist/lenet.py +++ b/experiments/classification/mnist/lenet.py @@ -1,13 +1,14 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.datamodules import MNISTDataModule from torch_uncertainty.routines import ClassificationRoutine -from torch_uncertainty.utils import TULightningCLI class MNISTCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) parser.add_optimizer_args(torch.optim.SGD) diff --git a/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml new file mode 100644 index 00000000..330181df --- /dev/null +++ b/experiments/classification/tiny-imagenet/configs/resnet18/standard.yaml @@ -0,0 +1,45 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/resnet18 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + 
init_args:
+        monitor: val/cls/Acc
+        patience: 1000
+        check_finite: true
+model:
+  num_classes: 200
+  in_channels: 3
+  loss: CrossEntropyLoss
+  version: std
+  arch: 18
+  style: cifar
+data:
+  root: ./data
+  batch_size: 256
+optimizer:
+  lr: 0.2
+  momentum: 0.9
+  weight_decay: 1e-4
+lr_scheduler:
+  eta_min: 0.0
+  T_max: 200
diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py
index e003ae84..1959cf6c 100644
--- a/experiments/classification/tiny-imagenet/resnet.py
+++ b/experiments/classification/tiny-imagenet/resnet.py
@@ -1,86 +1,28 @@
-from pathlib import Path
+import torch
+from lightning.pytorch.cli import LightningArgumentParser

-from torch import nn, optim
-
-from torch_uncertainty import cli_main, init_args
-from torch_uncertainty.baselines import ResNetBaseline
+from torch_uncertainty import TULightningCLI
+from torch_uncertainty.baselines.classification import ResNetBaseline
 from torch_uncertainty.datamodules import TinyImageNetDataModule
-from torch_uncertainty.optim_recipes import get_procedure
-from torch_uncertainty.utils import csv_writer
-
-
-def optim_tiny(model: nn.Module) -> dict:
-    optimizer = optim.SGD(
-        model.parameters(), lr=0.2, weight_decay=1e-4, momentum=0.9
-    )
-    scheduler = optim.lr_scheduler.CosineAnnealingLR(
-        optimizer,
-        eta_min=0,
-        T_max=200,
-    )
-    return {"optimizer": optimizer, "lr_scheduler": scheduler}

-if __name__ == "__main__":
-    args = init_args(ResNetBaseline, TinyImageNetDataModule)
-    if args.root == "./data/":
-        root = Path(__file__).parent.absolute().parents[2]
-    else:
-        root = Path(args.root)
-
-    if args.exp_name == "":
-        args.exp_name = f"{args.version}-resnet{args.arch}-tinyimagenet"
-    # datamodule
-    args.root = str(root / "data")
-    dm = TinyImageNetDataModule(**vars(args))
+class ResNetCLI(TULightningCLI):
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        super().add_arguments_to_parser(parser)
+        parser.add_optimizer_args(torch.optim.SGD)
+        parser.add_lr_scheduler_args(torch.optim.lr_scheduler.CosineAnnealingLR)

-    if args.opt_temp_scaling:
-        calibration_set = dm.get_test_set
-    elif args.val_temp_scaling:
-        calibration_set = dm.get_val_set
-    else:
-        calibration_set = None
-    if args.use_cv:
-        list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over)
-        list_model = [
-            ResNetBaseline(
-                num_classes=list_dm[i].dm.num_classes,
-                in_channels=list_dm[i].dm.num_channels,
-                loss=nn.CrossEntropyLoss(),
-                optim_recipe=get_procedure(
-                    f"resnet{args.arch}", "tiny-imagenet", args.version
-                ),
-                style="cifar",
-                calibration_set=calibration_set,
-                **vars(args),
-            )
-            for i in range(len(list_dm))
-        ]
+def cli_main() -> ResNetCLI:
+    return ResNetCLI(ResNetBaseline, TinyImageNetDataModule)

-        results = cli_main(
-            list_model, list_dm, args.exp_dir, args.exp_name, args
-        )
-    else:
-        # model
-        model = ResNetBaseline(
-            num_classes=dm.num_classes,
-            in_channels=dm.num_channels,
-            loss=nn.CrossEntropyLoss(),
-            optim_recipe=get_procedure(
-                f"resnet{args.arch}", "tiny-imagenet", args.version
-            ),
-            calibration_set=calibration_set,
-            style="cifar",
-            **vars(args),
-        )
-        results = cli_main(model, dm, args.exp_dir, args.exp_name, args)
-
-    if results is not None:
-        for dict_result in results:
-            csv_writer(
-                Path(args.exp_dir) / Path(args.exp_name) / "results.csv",
-                dict_result,
-            )
+if __name__ == "__main__":
+    torch.set_float32_matmul_precision("medium")
+    cli = cli_main()
+    if (
+        (not cli.trainer.fast_dev_run)
+        and cli.subcommand == "fit"
+        and
cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/depth/kitti/bts.py b/experiments/depth/kitti/bts.py index 456784e3..3fdfae82 100644 --- a/experiments/depth/kitti/bts.py +++ b/experiments/depth/kitti/bts.py @@ -1,9 +1,9 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.depth import BTSBaseline from torch_uncertainty.datamodules.depth import KITTIDataModule -from torch_uncertainty.utils import TULightningCLI from torch_uncertainty.utils.learning_rate import PolyLR diff --git a/experiments/depth/nyu/bts.py b/experiments/depth/nyu/bts.py index 20cc0330..d8aac63b 100644 --- a/experiments/depth/nyu/bts.py +++ b/experiments/depth/nyu/bts.py @@ -1,9 +1,9 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.depth import BTSBaseline from torch_uncertainty.datamodules.depth import NYUv2DataModule -from torch_uncertainty.utils import TULightningCLI from torch_uncertainty.utils.learning_rate import PolyLR diff --git a/experiments/regression/uci_datasets/mlp.py b/experiments/regression/uci_datasets/mlp.py index a0605472..7c187673 100644 --- a/experiments/regression/uci_datasets/mlp.py +++ b/experiments/regression/uci_datasets/mlp.py @@ -1,9 +1,9 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.regression import MLPBaseline from torch_uncertainty.datamodules import UCIDataModule -from torch_uncertainty.utils import TULightningCLI class MLPCLI(TULightningCLI): diff --git a/experiments/segmentation/camvid/configs/deeplab.yaml b/experiments/segmentation/camvid/configs/deeplab.yaml new file mode 100644 index 00000000..d8bcdc4c --- /dev/null +++ b/experiments/segmentation/camvid/configs/deeplab.yaml @@ -0,0 +1,42 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 120 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/deeplab + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/seg/mIoU + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +model: + num_classes: 11 + loss: CrossEntropyLoss + version: std + arch: 50 + style: v3+ + output_stride: 16 + separable: false +data: + root: ./data + batch_size: 8 + num_workers: 8 +optimizer: + lr: 0.002 + weight_decay: 1e-4 + momentum: 0.9 +lr_scheduler: + power: 1.0 + total_iters: 120 diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index 0dfac0a0..96f200b3 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -4,8 +4,23 @@ seed_everything: false trainer: accelerator: gpu devices: 1 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/segformer + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/seg/mIoU + mode: max + save_last: true + - class_path: 
lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: - num_classes: 12 + num_classes: 11 loss: CrossEntropyLoss version: std arch: 0 diff --git a/experiments/segmentation/camvid/deeplab.py b/experiments/segmentation/camvid/deeplab.py new file mode 100644 index 00000000..efebd51b --- /dev/null +++ b/experiments/segmentation/camvid/deeplab.py @@ -0,0 +1,28 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty import TULightningCLI +from torch_uncertainty.baselines.segmentation import DeepLabBaseline +from torch_uncertainty.datamodules.segmentation import CamVidDataModule +from torch_uncertainty.utils.learning_rate import PolyLR + + +class DeepLabV3CLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(PolyLR) + + +def cli_main() -> DeepLabV3CLI: + return DeepLabV3CLI(DeepLabBaseline, CamVidDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/segmentation/camvid/segformer.py b/experiments/segmentation/camvid/segformer.py index 8eecfb50..537ccf64 100644 --- a/experiments/segmentation/camvid/segformer.py +++ b/experiments/segmentation/camvid/segformer.py @@ -1,9 +1,9 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.segmentation import SegFormerBaseline from torch_uncertainty.datamodules.segmentation import CamVidDataModule -from torch_uncertainty.utils import TULightningCLI class SegFormerCLI(TULightningCLI): diff --git a/experiments/segmentation/cityscapes/deeplab.py b/experiments/segmentation/cityscapes/deeplab.py index ce064b05..0b074a74 100644 --- a/experiments/segmentation/cityscapes/deeplab.py +++ b/experiments/segmentation/cityscapes/deeplab.py @@ -1,9 +1,9 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.segmentation import DeepLabBaseline from torch_uncertainty.datamodules.segmentation import CityscapesDataModule -from torch_uncertainty.utils import TULightningCLI from torch_uncertainty.utils.learning_rate import PolyLR diff --git a/experiments/segmentation/cityscapes/segformer.py b/experiments/segmentation/cityscapes/segformer.py index 2b7fe992..6fb976bb 100644 --- a/experiments/segmentation/cityscapes/segformer.py +++ b/experiments/segmentation/cityscapes/segformer.py @@ -1,9 +1,9 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.segmentation import SegFormerBaseline from torch_uncertainty.datamodules.segmentation import CityscapesDataModule -from torch_uncertainty.utils import TULightningCLI class SegFormerCLI(TULightningCLI): diff --git a/experiments/segmentation/muad/segformer.py b/experiments/segmentation/muad/segformer.py index 67ad9564..3feb6271 100644 --- a/experiments/segmentation/muad/segformer.py +++ b/experiments/segmentation/muad/segformer.py @@ -1,9 +1,9 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch_uncertainty import TULightningCLI from 
torch_uncertainty.baselines.segmentation import SegFormerBaseline from torch_uncertainty.datamodules.segmentation import MUADDataModule -from torch_uncertainty.utils import TULightningCLI class SegFormerCLI(TULightningCLI): diff --git a/pyproject.toml b/pyproject.toml index 8ef51b13..eddd4df8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.2.2.post2" +version = "0.3.0" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, @@ -27,6 +27,8 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3 :: Only", ] dependencies = [ @@ -36,22 +38,25 @@ dependencies = [ "tensorboard", "einops", "torchinfo", - "scipy", "huggingface-hub", "scikit-learn", - "matplotlib==3.5.2", - "numpy<2", - "opencv-python", - "glest==0.0.1a0", + "matplotlib", + "numpy", "rich>=10.2.2", + "seaborn", ] [project.optional-dependencies] -image = ["scikit-image", "h5py"] +image = [ + "scikit-image", + "h5py", + "opencv-python", + "Wand", +] tabular = ["pandas"] dev = [ "torch_uncertainty[image]", - "ruff==0.5.2", + "ruff==0.6.9", "pytest-cov", "pre-commit", "pre-commit-hooks", @@ -66,7 +71,9 @@ docs = [ ] all = [ "torch_uncertainty[dev,docs,image,tabular]", - "laplace-torch" + "laplace-torch", + "glest==0.0.1a1", + "scipy", ] [project.urls] diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 535cd567..c76efc82 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -32,6 +32,7 @@ def __new__( with_feats: bool = True, ood_criterion: str = "msp", eval_ood: bool = False, + eval_shift: bool = False, eval_grouping_loss: bool = False, calibrate: bool = False, save_in_csv: bool = False, @@ -79,6 +80,7 @@ def __new__( mixup_params=mixup_params, ood_criterion=ood_criterion, eval_ood=eval_ood, + eval_shift=eval_shift, eval_grouping_loss=eval_grouping_loss, post_processing=TemperatureScaler() if calibrate else None, save_in_csv=save_in_csv, @@ -98,6 +100,7 @@ def __new__( is_ensemble=True, ood_criterion=ood_criterion, eval_ood=eval_ood, + eval_shift=eval_shift, eval_grouping_loss=eval_grouping_loss, post_processing=TemperatureScaler() if calibrate else None, save_in_csv=save_in_csv, diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 9dc59dfd..66555c6c 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from torchvision import tv_tensors -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from .dataset import ( DummPixelRegressionDataset, @@ -29,6 +29,7 @@ def __init__( num_classes: int = 2, num_workers: int = 1, eval_ood: bool = False, + eval_shift: bool = False, pin_memory: bool = True, persistent_workers: bool = True, num_images: int = 2, @@ -43,11 +44,13 @@ def __init__( ) self.eval_ood = eval_ood + self.eval_shift = eval_shift self.num_classes = num_classes self.num_images = num_images self.dataset = DummyClassificationDataset self.ood_dataset = DummyClassificationDataset + self.shift_dataset = DummyClassificationDataset self.train_transform = T.ToTensor() self.test_transform = T.ToTensor() @@ -90,11 
+93,22 @@ def setup(self, stage: str | None = None) -> None: transform=self.test_transform, num_images=self.num_images, ) + self.shift = self.shift_dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transform=self.test_transform, + num_images=self.num_images, + ) + self.shift.shift_severity = 1 def test_dataloader(self) -> DataLoader | list[DataLoader]: dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader def _get_train_data(self) -> ArrayLike: @@ -163,6 +177,8 @@ def test_dataloader(self) -> DataLoader | list[DataLoader]: class DummySegmentationDataModule(TUDataModule): num_channels = 3 training_task = "segmentation" + mean = [0.0, 0.0, 0.0] + std = [1.0, 1.0, 1.0] def __init__( self, diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index 64944684..7f34ac7b 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -13,10 +13,11 @@ def test_cifar10_main(self): dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR10 - assert isinstance(dm.train_transform.transforms[1], Cutout) + assert isinstance(dm.train_transform.transforms[2], Cutout) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.prepare_data() dm.setup() @@ -34,19 +35,17 @@ def test_cifar10_main(self): dm.test_dataloader() dm.eval_ood = True + dm.eval_shift = True dm.prepare_data() dm.setup("test") dm.test_dataloader() dm = CIFAR10DataModule( - root="./data/", batch_size=128, cutout=16, test_alt="c" - ) - dm.dataset = DummyClassificationDataset - with pytest.raises(ValueError): - dm.setup() - - dm = CIFAR10DataModule( - root="./data/", batch_size=128, cutout=16, test_alt="h" + root="./data/", + batch_size=128, + cutout=16, + test_alt="h", + basic_augment=False, ) dm.dataset = DummyClassificationDataset dm.setup("test") @@ -73,6 +72,13 @@ def test_cifar10_main(self): auto_augment="rand-m9-n2-mstd0.5", ) + with pytest.raises(ValueError, match="Test set "): + dm = CIFAR10DataModule( + root="./data/", + batch_size=128, + test_alt="x", + ) + dm = CIFAR10DataModule( root="./data/", batch_size=128, diff --git a/tests/datamodules/classification/test_cifar100.py b/tests/datamodules/classification/test_cifar100.py index f2f00aa2..bdb32a56 100644 --- a/tests/datamodules/classification/test_cifar100.py +++ b/tests/datamodules/classification/test_cifar100.py @@ -13,10 +13,11 @@ def test_cifar100(self): dm = CIFAR100DataModule(root="./data/", batch_size=128, cutout=16) assert dm.dataset == CIFAR100 - assert isinstance(dm.train_transform.transforms[1], Cutout) + assert isinstance(dm.train_transform.transforms[2], Cutout) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.prepare_data() dm.setup() @@ -26,29 +27,22 @@ def test_cifar100(self): dm.test_dataloader() dm.eval_ood = True + dm.eval_shift = True dm.prepare_data() dm.setup("test") dm.test_dataloader() - dm = CIFAR100DataModule( - root="./data/", batch_size=128, cutout=0, test_alt="c" - ) - dm.dataset = DummyClassificationDataset - dm.setup("test") - with pytest.raises(ValueError): - dm.setup() - dm = CIFAR100DataModule( root="./data/", batch_size=128, 
cutout=0, - test_alt=None, val_split=0.1, num_dataloaders=2, + basic_augment=False, ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.setup() dm.setup("test") dm.train_dataloader() diff --git a/tests/datamodules/classification/test_imagenet.py b/tests/datamodules/classification/test_imagenet.py index 9088a701..0d514d82 100644 --- a/tests/datamodules/classification/test_imagenet.py +++ b/tests/datamodules/classification/test_imagenet.py @@ -24,6 +24,7 @@ def test_imagenet(self): dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=path) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.setup("fit") dm.setup("test") dm.train_dataloader() @@ -37,6 +38,7 @@ def test_imagenet(self): dm.test_dataloader() dm.eval_ood = True + dm.eval_shift = True dm.prepare_data() dm.setup("test") dm.test_dataloader() @@ -46,6 +48,7 @@ def test_imagenet(self): batch_size=128, val_split=path, rand_augment_opt="rand-m9-n1", + basic_augment=False, ) with pytest.raises(ValueError): diff --git a/tests/datamodules/classification/test_mnist.py b/tests/datamodules/classification/test_mnist.py index ba30fad3..7a967edc 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -20,7 +20,7 @@ def test_mnist_cutout(self): ) assert dm.dataset == MNIST - assert isinstance(dm.train_transform.transforms[1], Cutout) + assert isinstance(dm.train_transform.transforms[2], Cutout) dm = MNISTDataModule( root="./data/", @@ -28,20 +28,17 @@ def test_mnist_cutout(self): ood_ds="notMNIST", cutout=0, val_split=0, + basic_augment=False, ) - assert isinstance(dm.train_transform.transforms[1], nn.Identity) + assert isinstance(dm.train_transform.transforms[2], nn.Identity) with pytest.raises(ValueError): MNISTDataModule(root="./data/", batch_size=128, ood_ds="other") - MNISTDataModule(root="./data/", batch_size=128, test_alt="c") - dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset - - dm.prepare_data() - dm.setup() - + dm.setup("fit") + dm.setup("test") dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() diff --git a/tests/datamodules/classification/test_tiny_imagenet.py b/tests/datamodules/classification/test_tiny_imagenet.py index 007b5f4d..6f70313f 100644 --- a/tests/datamodules/classification/test_tiny_imagenet.py +++ b/tests/datamodules/classification/test_tiny_imagenet.py @@ -21,7 +21,10 @@ def test_tiny_imagenet(self): ) dm = TinyImageNetDataModule( - root="./data/", batch_size=128, ood_ds="textures" + root="./data/", + batch_size=128, + ood_ds="textures", + basic_augment=False, ) with pytest.raises(ValueError): @@ -31,6 +34,7 @@ def test_tiny_imagenet(self): dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.prepare_data() dm.setup() @@ -43,6 +47,7 @@ def test_tiny_imagenet(self): dm.setup("other") dm.eval_ood = True + dm.eval_shift = True dm.prepare_data() dm.setup("test") dm.test_dataloader() @@ -52,7 +57,9 @@ def test_tiny_imagenet(self): ) dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.eval_ood = True + dm.eval_shift = True dm.prepare_data() dm.setup("test") diff --git a/tests/datamodules/segmentation/test_camvid.py 
b/tests/datamodules/segmentation/test_camvid.py index 23b1d4d3..eaccb088 100644 --- a/tests/datamodules/segmentation/test_camvid.py +++ b/tests/datamodules/segmentation/test_camvid.py @@ -9,7 +9,12 @@ class TestCamVidDataModule: """Testing the CamVidDataModule datamodule.""" def test_camvid_main(self): - dm = CamVidDataModule(root="./data/", batch_size=128) + dm = CamVidDataModule( + root="./data/", batch_size=128, group_classes=False + ) + dm = CamVidDataModule( + root="./data/", batch_size=128, basic_augment=False + ) assert dm.dataset == CamVid diff --git a/tests/datamodules/segmentation/test_cityscapes.py b/tests/datamodules/segmentation/test_cityscapes.py index 25b0bbd1..46781c8c 100644 --- a/tests/datamodules/segmentation/test_cityscapes.py +++ b/tests/datamodules/segmentation/test_cityscapes.py @@ -10,6 +10,9 @@ class TestCityscapesDataModule: def test_camvid_main(self): dm = CityscapesDataModule(root="./data/", batch_size=128) + dm = CityscapesDataModule( + root="./data/", batch_size=128, basic_augment=False + ) assert dm.dataset == Cityscapes diff --git a/tests/datamodules/classification/test_uci_regression.py b/tests/datamodules/test_uci_regression.py similarity index 100% rename from tests/datamodules/classification/test_uci_regression.py rename to tests/datamodules/test_uci_regression.py diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py index 3568dcdb..11989804 100644 --- a/tests/layers/test_packed.py +++ b/tests/layers/test_packed.py @@ -11,7 +11,7 @@ @pytest.fixture() def feat_input() -> torch.Tensor: - return torch.rand((6, 1)) + return torch.rand((6, 1)) # (Cin, Lin) @pytest.fixture() @@ -19,6 +19,11 @@ def feat_input_one_rearrange() -> torch.Tensor: return torch.rand((1 * 3, 5)) +@pytest.fixture() +def feat_input_16_features() -> torch.Tensor: + return torch.rand((2, 16)) + + @pytest.fixture() def seq_input() -> torch.Tensor: return torch.rand((5, 6, 3)) @@ -37,8 +42,11 @@ def voxels_input() -> torch.Tensor: class TestPackedLinear: """Testing the PackedLinear layer class.""" + # Legacy tests def test_linear_one_estimator_no_rearrange(self, feat_input: torch.Tensor): - layer = PackedLinear(6, 2, alpha=1, num_estimators=1, rearrange=False) + layer = PackedLinear( + 6, 2, alpha=1, num_estimators=1, rearrange=False, bias=False + ) out = layer(feat_input) assert out.shape == torch.Size([2, 1]) @@ -60,6 +68,48 @@ def test_linear_two_estimator_rearrange_not_divisible(self): out = layer(feat) assert out.shape == torch.Size([6, 1]) + def test_linear_full_implementation( + self, feat_input_16_features: torch.Tensor + ): + layer = PackedLinear( + 16, 4, alpha=1, num_estimators=1, implementation="full" + ) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([2, 4]) + layer = PackedLinear( + 16, 4, alpha=1, num_estimators=2, implementation="full" + ) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([2, 4]) + + def test_linear_sparse_implementation( + self, feat_input_16_features: torch.Tensor + ): + layer = PackedLinear( + 16, 4, alpha=1, num_estimators=1, implementation="sparse" + ) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([2, 4]) + layer = PackedLinear( + 16, 4, alpha=1, num_estimators=2, implementation="sparse" + ) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([2, 4]) + + def test_linear_einsum_implementation( + self, feat_input_16_features: torch.Tensor + ): + layer = PackedLinear( + 16, 4, alpha=1, num_estimators=1, implementation="einsum" + ) + out = 
layer(feat_input_16_features) + assert out.shape == torch.Size([2, 4]) + layer = PackedLinear( + 16, 4, alpha=1, num_estimators=2, implementation="einsum" + ) + out = layer(feat_input_16_features) + assert out.shape == torch.Size([2, 4]) + def test_linear_extend(self): _ = PackedConv2d( 5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1 @@ -91,6 +141,23 @@ def test_linear_failures(self): 5, 2, alpha=1, num_estimators=1, gamma=-1, rearrange=True ) + with pytest.raises(AssertionError): + _ = PackedLinear( + 5, + 2, + alpha=1, + num_estimators=1, + gamma=1, + implementation="invalid", + ) + + with pytest.raises(ValueError): + layer = PackedLinear( + 16, 4, alpha=1, num_estimators=1, implementation="full" + ) + layer.implementation = "invalid" + _ = layer(torch.rand((2, 16))) + class TestPackedConv1d: """Testing the PackedConv1d layer class.""" diff --git a/tests/losses/test_classification.py b/tests/losses/test_classification.py index f5bb2400..fa79e562 100644 --- a/tests/losses/test_classification.py +++ b/tests/losses/test_classification.py @@ -5,6 +5,7 @@ ConfidencePenaltyLoss, ConflictualLoss, DECLoss, + FocalLoss, ) @@ -106,3 +107,27 @@ def test_failures(self): ValueError, match="is not a valid value for reduction." ): ConflictualLoss(reduction="median") + + +class TestFocalLoss: + """Testing the FocalLoss class.""" + + def test_main(self): + loss = FocalLoss(gamma=1, reduction="sum") + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + loss = FocalLoss(gamma=0.5) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + loss = FocalLoss(gamma=0.5, reduction=None) + loss(torch.tensor([[0.0, 0.0]]), torch.tensor([0])) + + def test_failures(self): + with pytest.raises( + ValueError, + match="The gamma term of the focal loss should be non-negative, but got", + ): + FocalLoss(gamma=-1) + + with pytest.raises( + ValueError, match="is not a valid value for reduction." 
+ ): + FocalLoss(gamma=1, reduction="median") diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py index cf94da78..ccc29ca2 100644 --- a/tests/metrics/classification/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -15,11 +15,12 @@ def test_plot_binary(self) -> None: torch.as_tensor([0, 0, 1, 1, 1]), ) fig, ax = metric.plot() - metric.plot(ax=ax) assert isinstance(fig, plt.Figure) - assert isinstance(ax, plt.Axes) - assert ax.get_xlabel() == "Top-class Confidence (%)" - assert ax.get_ylabel() == "Success Rate (%)" + assert ax[0].get_xlabel() == "Top-class Confidence (%)" + assert ax[0].get_ylabel() == "Success Rate (%)" + assert ax[1].get_xlabel() == "Top-class Confidence (%)" + assert ax[1].get_ylabel() == "Density" + plt.close(fig) def test_plot_multiclass( @@ -41,9 +42,10 @@ def test_plot_multiclass( ) fig, ax = metric.plot() assert isinstance(fig, plt.Figure) - assert isinstance(ax, plt.Axes) - assert ax.get_xlabel() == "Top-class Confidence (%)" - assert ax.get_ylabel() == "Success Rate (%)" + assert ax[0].get_xlabel() == "Top-class Confidence (%)" + assert ax[0].get_ylabel() == "Success Rate (%)" + assert ax[1].get_xlabel() == "Top-class Confidence (%)" + assert ax[1].get_ylabel() == "Density" plt.close(fig) def test_errors(self) -> None: diff --git a/tests/models/test_wideresnets.py b/tests/models/test_wideresnets.py index 2bc3cbfe..72c518fc 100644 --- a/tests/models/test_wideresnets.py +++ b/tests/models/test_wideresnets.py @@ -47,6 +47,7 @@ def test_main(self): widen_factor=20, in_channels=3, num_classes=10, + num_estimators=4, conv_bias=False, dropout_rate=0.0, ) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 999a9663..14fa99d9 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -8,10 +8,10 @@ DummyClassificationDataModule, dummy_model, ) +from torch_uncertainty import TUTrainer from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.routines import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget -from torch_uncertainty.utils import TUTrainer class TestClassification: @@ -72,6 +72,7 @@ def test_one_estimator_two_classes(self): num_classes=2, num_images=100, eval_ood=True, + eval_shift=True, ) model = DummyClassificationBaseline( num_classes=dm.num_classes, @@ -80,6 +81,7 @@ def test_one_estimator_two_classes(self): baseline_type="single", ood_criterion="entropy", eval_ood=True, + eval_shift=True, no_mixup_params=True, ) diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py index 56e2058d..c60b8d7a 100644 --- a/tests/routines/test_pixel_regression.py +++ b/tests/routines/test_pixel_regression.py @@ -8,13 +8,13 @@ DummyPixelRegressionBaseline, DummyPixelRegressionDataModule, ) +from torch_uncertainty import TUTrainer from torch_uncertainty.losses import DistributionNLLLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines.pixel_regression import ( PixelRegressionRoutine, colorize, ) -from torch_uncertainty.utils import TUTrainer class TestPixelRegression: diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 7c03ab1e..8f1fa6ed 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -4,10 +4,10 @@ from torch import nn from tests._dummies import DummyRegressionBaseline, DummyRegressionDataModule 
+from torch_uncertainty import TUTrainer from torch_uncertainty.losses import DistributionNLLLoss from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import RegressionRoutine -from torch_uncertainty.utils import TUTrainer class TestRegression: diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 7168e607..beed8a1d 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -7,9 +7,9 @@ DummySegmentationBaseline, DummySegmentationDataModule, ) +from torch_uncertainty import TUTrainer from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import SegmentationRoutine -from torch_uncertainty.utils import TUTrainer class TestSegmentation: diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 4d979f89..4e1a5e59 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -3,16 +3,24 @@ from requests.exceptions import HTTPError from torch_uncertainty.transforms.corruption import ( + Brightness, + Contrast, DefocusBlur, + Elastic, + Fog, Frost, GaussianBlur, GaussianNoise, GlassBlur, ImpulseNoise, JPEGCompression, + MotionBlur, Pixelate, + Saturation, ShotNoise, + Snow, SpeckleNoise, + ZoomBlur, ) @@ -32,100 +40,74 @@ def test_gaussian_noise(self): print(transform) def test_shot_noise(self): - with pytest.raises(ValueError): - _ = ShotNoise(-1) - with pytest.raises(TypeError): - _ = ShotNoise(0.1) inputs = torch.rand(3, 32, 32) transform = ShotNoise(1) transform(inputs) transform = ShotNoise(0) transform(inputs) - print(transform) def test_impulse_noise(self): - with pytest.raises(ValueError): - _ = ImpulseNoise(-1) - with pytest.raises(TypeError): - _ = ImpulseNoise(0.1) inputs = torch.rand(3, 32, 32) transform = ImpulseNoise(1) transform(inputs) transform = ImpulseNoise(0) transform(inputs) - print(transform) def test_speckle_noise(self): - with pytest.raises(ValueError): - _ = SpeckleNoise(-1) - with pytest.raises(TypeError): - _ = SpeckleNoise(0.1) inputs = torch.rand(3, 32, 32) transform = SpeckleNoise(1) transform(inputs) transform = SpeckleNoise(0) transform(inputs) - print(transform) def test_gaussian_blur(self): - with pytest.raises(ValueError): - _ = GaussianBlur(-1) - with pytest.raises(TypeError): - _ = GaussianBlur(0.1) inputs = torch.rand(3, 32, 32) transform = GaussianBlur(1) transform(inputs) transform = GaussianBlur(0) transform(inputs) - print(transform) def test_glass_blur(self): - with pytest.raises(ValueError): - _ = GlassBlur(-1) - with pytest.raises(TypeError): - _ = GlassBlur(0.1) inputs = torch.rand(3, 32, 32) transform = GlassBlur(1) transform(inputs) transform = GlassBlur(0) transform(inputs) - print(transform) def test_defocus_blur(self): - with pytest.raises(ValueError): - _ = DefocusBlur(-1) - with pytest.raises(TypeError): - _ = DefocusBlur(0.1) inputs = torch.rand(3, 32, 32) transform = DefocusBlur(1) transform(inputs) transform = DefocusBlur(0) transform(inputs) - print(transform) + + def test_motion_blur(self): + inputs = torch.rand(3, 32, 32) + transform = MotionBlur(1) + transform(inputs) + transform = MotionBlur(0) + transform(inputs) + + def test_zoom_blur(self): + inputs = torch.rand(3, 32, 32) + transform = ZoomBlur(1) + transform(inputs) + transform = ZoomBlur(0) + transform(inputs) def test_jpeg_compression(self): - with pytest.raises(ValueError): - _ = JPEGCompression(-1) - with pytest.raises(TypeError): - _ = 
JPEGCompression(0.1) inputs = torch.rand(3, 32, 32) transform = JPEGCompression(1) transform(inputs) transform = JPEGCompression(0) transform(inputs) - print(transform) def test_pixelate(self): - with pytest.raises(ValueError): - _ = Pixelate(-1) - with pytest.raises(TypeError): - _ = Pixelate(0.1) inputs = torch.rand(3, 32, 32) transform = Pixelate(1) transform(inputs) transform = Pixelate(0) transform(inputs) - print(transform) def test_frost(self): try: @@ -134,13 +116,50 @@ def test_frost(self): except HTTPError: frost_ok = False if frost_ok: - with pytest.raises(ValueError): - _ = Frost(-1) - with pytest.raises(TypeError): - _ = Frost(0.1) inputs = torch.rand(3, 32, 32) transform = Frost(1) transform(inputs) transform = Frost(0) transform(inputs) - print(transform) + + def test_snow(self): + inputs = torch.rand(3, 32, 32) + transform = Snow(1) + transform(inputs) + transform = Snow(0) + transform(inputs) + + def test_fog(self): + inputs = torch.rand(3, 32, 32) + transform = Fog(1, size=32) + transform(inputs) + transform = Fog(0, size=32) + transform(inputs) + + def test_brightness(self): + inputs = torch.rand(3, 32, 32) + transform = Brightness(1) + transform(inputs) + transform = Brightness(0) + transform(inputs) + + def test_contrast(self): + inputs = torch.rand(3, 32, 32) + transform = Contrast(1) + transform(inputs) + transform = Contrast(0) + transform(inputs) + + def test_elastic(self): + inputs = torch.rand(3, 32, 32) + transform = Elastic(1) + transform(inputs) + transform = Elastic(0) + transform(inputs) + + def test_saturation(self): + inputs = torch.rand(3, 32, 32) + transform = Saturation(1) + transform(inputs) + transform = Saturation(0) + transform(inputs) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index e69de29b..4182d817 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .utils import TULightningCLI, TUTrainer diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index e6188322..eef4abe3 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -22,6 +22,7 @@ def __init__( checkpoint_ids: list[int], backbone: Literal["resnet", "vgg", "wideresnet"], eval_ood: bool = False, + eval_shift: bool = False, eval_grouping_loss: bool = False, ood_criterion: Literal[ "msp", "logit", "energy", "entropy", "mi", "vr" @@ -53,6 +54,7 @@ def __init__( loss=None, is_ensemble=de.num_estimators > 1, eval_ood=eval_ood, + eval_shift=eval_shift, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, log_plots=log_plots, diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 15d1655d..f1c5c486 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -73,6 +73,7 @@ def __init__( save_in_csv: bool = False, calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, + eval_shift: bool = False, eval_grouping_loss: bool = False, num_calibration_bins: int = 15, pretrained: bool = False, @@ -156,6 +157,8 @@ def __init__( ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. 
eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. num_calibration_bins (int, optional): Number of calibration bins. @@ -235,6 +238,7 @@ def __init__( format_batch_fn=format_batch_fn, mixup_params=mixup_params, eval_ood=eval_ood, + eval_shift=eval_shift, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, log_plots=log_plots, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index f9c29323..7dc65a59 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -45,6 +45,7 @@ def __init__( save_in_csv: bool = False, calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, + eval_shift: bool = False, eval_grouping_loss: bool = False, ) -> None: r"""VGG backbone baseline for classification providing support for @@ -105,6 +106,8 @@ def __init__( ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. @@ -173,6 +176,7 @@ def __init__( optim_recipe=optim_recipe, mixup_params=mixup_params, eval_ood=eval_ood, + eval_shift=eval_shift, ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index b1086200..a51d08cf 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -56,6 +56,7 @@ def __init__( save_in_csv: bool = False, calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, + eval_shift: bool = False, eval_grouping_loss: bool = False, ) -> None: r"""Wide-ResNet28x10 backbone baseline for classification providing support @@ -119,6 +120,8 @@ def __init__( ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. 
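Note that the baseline changes in this diff build on the package-root re-export added in torch_uncertainty/__init__.py above, which is also why the test files at the top of this section now import the trainer directly. A minimal sketch of the new import path follows; it assumes TUTrainer accepts the usual Lightning Trainer keyword arguments, since only the names, not the signature, appear in this diff:

from torch_uncertainty import TULightningCLI, TUTrainer

# Both names are re-exported from torch_uncertainty.utils, so the old
# import path keeps working. accelerator/max_epochs are assumed to be
# forwarded to the underlying Lightning Trainer.
trainer = TUTrainer(accelerator="cpu", max_epochs=1)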
@@ -191,6 +194,7 @@ def __init__( optim_recipe=optim_recipe, mixup_params=mixup_params, eval_ood=eval_ood, + eval_shift=eval_shift, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, log_plots=log_plots, diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index 24859701..670dfc4c 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 +from .abstract import TUDataModule from .classification.cifar10 import CIFAR10DataModule from .classification.cifar100 import CIFAR100DataModule from .classification.imagenet import ImageNetDataModule diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 16308d72..75d303e4 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -19,7 +19,7 @@ def __init__( self, root: str | Path, batch_size: int, - val_split: float, + val_split: float | None, num_workers: int, pin_memory: bool, persistent_workers: bool, @@ -84,7 +84,7 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - List[DataLoader]: test set for in distribution data + list[DataLoader]: test set for in distribution data and out-of-distribution data. """ return [self._data_loader(self.test)] diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index a87d61c0..f33d4a01 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -21,19 +21,22 @@ class CIFAR10DataModule(TUDataModule): num_channels = 3 input_shape = (3, 32, 32) training_task = "classification" + mean = (0.4914, 0.4822, 0.4465) + std = (0.2023, 0.1994, 0.2010) def __init__( self, root: str | Path, batch_size: int, eval_ood: bool = False, + eval_shift: bool = False, + shift_severity: int = 1, val_split: float | None = None, num_workers: int = 1, basic_augment: bool = True, cutout: int | None = None, auto_augment: str | None = None, - test_alt: Literal["c", "h"] | None = None, - corruption_severity: int = 1, + test_alt: Literal["h"] | None = None, num_dataloaders: int = 1, pin_memory: bool = True, persistent_workers: bool = True, @@ -43,6 +46,9 @@ def __init__( Args: root (str): Root directory of the datasets. eval_ood (bool): Whether to evaluate on out-of-distribution data. + Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. batch_size (int): Number of samples per batch. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. @@ -55,7 +61,7 @@ def __init__( ``False``. auto_augment (str): Which auto-augment to apply. Defaults to ``None``. test_alt (str): Which test set to use. Defaults to ``None``. - corruption_severity (int): Severity of corruption to apply for + shift_severity (int): Severity of corruption to apply for CIFAR10-C. Defaults to ``1``. num_dataloaders (int): Number of dataloaders to use. Defaults to ``1``. pin_memory (bool): Whether to pin memory. Defaults to ``True``. 
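The eval_shift/shift_severity pair documented above wires CIFAR-10-C into the test stage; as the test_dataloader hunk later in this file shows, the shifted loader is appended after the in-distribution and optional OOD ones. A hedged usage sketch, using only names visible in this diff:

from torch_uncertainty.datamodules import CIFAR10DataModule

# shift_severity must be between 1 and 5, matching the CIFAR10C dataset.
dm = CIFAR10DataModule(
    root="data",
    batch_size=128,
    eval_ood=True,    # appends the SVHN out-of-distribution loader
    eval_shift=True,  # appends the CIFAR-10-C loader added by this diff
    shift_severity=3,
)
# dm.test_dataloader() then returns the [in-distribution, OOD, shifted] loaders.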
@@ -74,17 +80,19 @@ def __init__( self.val_split = val_split self.num_dataloaders = num_dataloaders self.eval_ood = eval_ood + self.eval_shift = eval_shift - if test_alt == "c": - self.dataset = CIFAR10C - elif test_alt == "h": + if test_alt == "h": self.dataset = CIFAR10H - else: + elif test_alt is None: self.dataset = CIFAR10 + else: + raise ValueError(f"Test set {test_alt} is not supported.") self.test_alt = test_alt - self.corruption_severity = corruption_severity + self.shift_severity = shift_severity self.ood_dataset = SVHN + self.shift_dataset = CIFAR10C if (cutout is not None) + int(auto_augment is not None) > 1: raise ValueError( @@ -111,12 +119,12 @@ def __init__( self.train_transform = T.Compose( [ + T.ToTensor(), basic_transform, main_transform, - T.ToTensor(), T.Normalize( - (0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010), + self.mean, + self.std, ), ] ) @@ -125,8 +133,8 @@ def __init__( [ T.ToTensor(), T.Normalize( - (0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010), + self.mean, + self.std, ), ] ) @@ -135,12 +143,6 @@ def prepare_data(self) -> None: # coverage: ignore if self.test_alt is None: self.dataset(self.root, train=True, download=True) self.dataset(self.root, train=False, download=True) - elif self.test_alt == "c": - self.dataset( - self.root, - severity=self.corruption_severity, - download=True, - ) else: self.dataset( self.root, @@ -149,6 +151,12 @@ def prepare_data(self) -> None: # coverage: ignore if self.eval_ood: self.ood_dataset(self.root, split="test", download=True) + if self.eval_shift: + self.shift_dataset( + self.root, + shift_severity=self.shift_severity, + download=True, + ) def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: @@ -187,7 +195,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: self.test = self.dataset( self.root, transform=self.test_transform, - severity=self.corruption_severity, + shift_severity=self.shift_severity, ) if self.eval_ood: self.ood = self.ood_dataset( @@ -196,6 +204,12 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.test_transform, ) + if self.eval_shift: + self.shift = self.shift_dataset( + self.root, + download=False, + transform=self.test_transform, + ) if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") @@ -216,12 +230,14 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - List[DataLoader]: test set for in distribution data + list[DataLoader]: test set for in distribution data and out-of-distribution data. 
""" dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index fa759853..5b6a5487 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C from torch_uncertainty.transforms import Cutout @@ -22,19 +22,21 @@ class CIFAR100DataModule(TUDataModule): num_channels = 3 input_shape = (3, 32, 32) training_task = "classification" + mean = (0.5071, 0.4867, 0.4408) + std = (0.2675, 0.2565, 0.2761) def __init__( self, root: str | Path, batch_size: int, eval_ood: bool = False, + eval_shift: bool = False, + shift_severity: int = 1, val_split: float | None = None, basic_augment: bool = True, cutout: int | None = None, randaugment: bool = False, auto_augment: str | None = None, - test_alt: Literal["c"] | None = None, - corruption_severity: int = 1, num_dataloaders: int = 1, num_workers: int = 1, pin_memory: bool = True, @@ -46,6 +48,8 @@ def __init__( root (str): Root directory of the datasets. eval_ood (bool): Whether to evaluate out-of-distribution performance. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. batch_size (int): Number of samples per batch. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. @@ -55,8 +59,7 @@ def __init__( randaugment (bool): Whether to apply RandAugment. Defaults to ``False``. auto_augment (str): Which auto-augment to apply. Defaults to ``None``. - test_alt (str): Which test set to use. Defaults to ``None``. - corruption_severity (int): Severity of corruption to apply to + shift_severity (int): Severity of corruption to apply to CIFAR100-C. Defaults to ``1``. num_dataloaders (int): Number of dataloaders to use. Defaults to ``1``. num_workers (int): Number of workers to use for data loading. 
Defaults @@ -75,18 +78,14 @@ def __init__( ) self.eval_ood = eval_ood + self.eval_shift = eval_shift self.num_dataloaders = num_dataloaders - if test_alt == "c": - self.dataset = CIFAR100C - else: - self.dataset = CIFAR100 - - self.test_alt = test_alt - + self.dataset = CIFAR100 self.ood_dataset = SVHN + self.shift_dataset = CIFAR100C - self.corruption_severity = corruption_severity + self.shift_severity = shift_severity if (cutout is not None) + randaugment + int( auto_augment is not None @@ -117,30 +116,23 @@ def __init__( self.train_transform = T.Compose( [ + T.ToTensor(), basic_transform, main_transform, - T.ToTensor(), T.ConvertImageDtype(torch.float32), - T.Normalize( - (0.5071, 0.4867, 0.4408), - (0.2675, 0.2565, 0.2761), - ), + T.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = T.Compose( [ T.ToTensor(), - T.Normalize( - (0.5071, 0.4867, 0.4408), - (0.2675, 0.2565, 0.2761), - ), + T.Normalize(mean=self.mean, std=self.std), ] ) def prepare_data(self) -> None: # coverage: ignore - if self.test_alt is None: - self.dataset(self.root, train=True, download=True) - self.dataset(self.root, train=False, download=True) + self.dataset(self.root, train=True, download=True) + self.dataset(self.root, train=False, download=True) if self.eval_ood: self.ood_dataset( @@ -149,11 +141,16 @@ def prepare_data(self) -> None: # coverage: ignore download=True, transform=self.test_transform, ) + if self.eval_shift: + self.shift_dataset( + self.root, + download=True, + transform=self.test_transform, + shift_severity=self.shift_severity, + ) def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: - if self.test_alt == "c": - raise ValueError("CIFAR-C can only be used in testing.") full = self.dataset( self.root, train=True, @@ -175,24 +172,24 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.test_transform, ) if stage == "test" or stage is None: - if self.test_alt is None: - self.test = self.dataset( + self.test = self.dataset( + self.root, + train=False, + download=False, + transform=self.test_transform, + ) + if self.eval_ood: + self.ood = self.ood_dataset( self.root, - train=False, + split="test", download=False, transform=self.test_transform, ) - else: - self.test = self.dataset( - self.root, - transform=self.test_transform, - severity=self.corruption_severity, - ) - if self.eval_ood: - self.ood = self.ood_dataset( + if self.eval_shift: + self.shift = self.shift_dataset( self.root, - split="test", download=False, + shift_severity=self.shift_severity, transform=self.test_transform, ) if stage not in ["fit", "test", None]: @@ -215,12 +212,14 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - List[DataLoader]: test set for in distribution data + list[DataLoader]: test set for in distribution data and out-of-distribution data. 
""" dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 1e19ed4a..f56bc34b 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -10,9 +10,10 @@ from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import ( ImageNetA, + ImageNetC, ImageNetO, ImageNetR, OpenImageO, @@ -35,6 +36,8 @@ class ImageNetDataModule(TUDataModule): "openimage-o", ] training_task = "classification" + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) train_indices = None val_indices = None @@ -43,6 +46,8 @@ def __init__( root: str | Path, batch_size: int, eval_ood: bool = False, + eval_shift: bool = False, + shift_severity: int = 1, val_split: float | Path | None = None, ood_ds: str = "openimage-o", test_alt: str | None = None, @@ -60,7 +65,10 @@ def __init__( Args: root (str): Root directory of the datasets. eval_ood (bool): Whether to evaluate out-of-distribution - performance. + performance. Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. + shift_severity: int = 1, batch_size (int): Number of samples per batch. val_split (float or Path): Share of samples to use for validation or path to a yaml file containing a list of validation images @@ -80,7 +88,6 @@ def __init__( pin_memory (bool): Whether to pin memory. Defaults to ``True``. persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. - kwargs: Additional arguments. 
""" super().__init__( root=Path(root), @@ -92,6 +99,8 @@ def __init__( ) self.eval_ood = eval_ood + self.eval_shift = eval_shift + self.shift_severity = shift_severity if val_split and not isinstance(val_split, float): val_split = Path(val_split) self.train_indices, self.val_indices = read_indices(val_split) @@ -123,6 +132,7 @@ def __init__( self.ood_dataset = OpenImageO else: raise ValueError(f"The dataset {ood_ds} is not supported.") + self.shift_dataset = ImageNetC self.procedure = procedure @@ -159,19 +169,19 @@ def __init__( self.train_transform = T.Compose( [ + T.ToTensor(), basic_transform, main_transform, - T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = T.Compose( [ + T.ToTensor(), T.Resize(256, interpolation=self.interpolation), T.CenterCrop(224), - T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.Normalize(mean=self.mean, std=self.std), ] ) @@ -211,6 +221,13 @@ def prepare_data(self) -> None: # coverage: ignore download=True, transform=self.test_transform, ) + if self.eval_shift: + self.shift_dataset( + self.root, + download=True, + transform=self.test_transform, + shift_severity=self.shift_severity, + ) def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: @@ -264,16 +281,26 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=True, ) + if self.eval_shift: + self.shift = self.shift_dataset( + self.root, + download=False, + transform=self.test_transform, + shift_severity=self.shift_severity, + ) + def test_dataloader(self) -> list[DataLoader]: """Get the test dataloaders for ImageNet. Return: - List[DataLoader]: ImageNet test set (in distribution data) and + list[DataLoader]: ImageNet test set (in distribution data) and Textures test split (out-of-distribution data). """ dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader @@ -284,7 +311,7 @@ def read_indices(path: Path) -> list[str]: # coverage: ignore path (Path): Path to the file. Returns: - list[str]: List of filenames. + list[str]: list of filenames. 
""" if not path.is_file(): raise ValueError(f"{path} is not a file.") diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 9be45ea6..604ac8c2 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split @@ -18,18 +18,20 @@ class MNISTDataModule(TUDataModule): input_shape = (1, 28, 28) training_task = "classification" ood_datasets = ["fashion", "notMNIST"] + mean = (0.1307,) + std = (0.3081,) def __init__( self, root: str | Path, batch_size: int, eval_ood: bool = False, + eval_shift: bool = False, ood_ds: Literal["fashion", "notMNIST"] = "fashion", val_split: float | None = None, num_workers: int = 1, basic_augment: bool = True, cutout: int | None = None, - test_alt: Literal["c"] | None = None, pin_memory: bool = True, persistent_workers: bool = True, ) -> None: @@ -38,6 +40,9 @@ def __init__( Args: root (str): Root directory of the datasets. eval_ood (bool): Whether to evaluate on out-of-distribution data. + Defaults to ``False``. + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. batch_size (int): Number of samples per batch. ood_ds (str): Which out-of-distribution dataset to use. Defaults to ``"fashion"``; `fashion` stands for FashionMNIST and `notMNIST` for @@ -49,11 +54,9 @@ def __init__( basic_augment (bool): Whether to apply base augmentations. Defaults to ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. - test_alt (str): Which test set to use. Defaults to ``None``. pin_memory (bool): Whether to pin memory. Defaults to ``True``. persistent_workers (bool): Whether to use persistent workers. Defaults to ``True``. - kwargs: Additional arguments. """ super().__init__( root=root, @@ -65,12 +68,10 @@ def __init__( ) self.eval_ood = eval_ood + self.eval_shift = eval_shift self.batch_size = batch_size - if test_alt == "c": - self.dataset = MNISTC - else: - self.dataset = MNIST + self.dataset = MNIST if ood_ds == "fashion": self.ood_dataset = FashionMNIST @@ -80,6 +81,8 @@ def __init__( raise ValueError( f"`ood_ds` should be in {self.ood_datasets}. Got {ood_ds}." 
) + self.shift_dataset = MNISTC + self.shift_severity = 1 if basic_augment: basic_transform = T.RandomCrop(28, padding=4) @@ -90,26 +93,26 @@ def __init__( self.train_transform = T.Compose( [ + T.ToTensor(), basic_transform, main_transform, - T.ToTensor(), - T.Normalize((0.1307,), (0.3081,)), + T.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = T.Compose( [ T.ToTensor(), T.CenterCrop(28), - T.Normalize((0.1307,), (0.3081,)), + T.Normalize(mean=self.mean, std=self.std), ] ) if self.eval_ood: # NotMNIST has 3 channels self.ood_transform = T.Compose( [ - T.Grayscale(num_output_channels=1), T.ToTensor(), + T.Grayscale(num_output_channels=1), T.CenterCrop(28), - T.Normalize((0.1307,), (0.3081,)), + T.Normalize(mean=self.mean, std=self.std), ] ) @@ -120,6 +123,8 @@ def prepare_data(self) -> None: # coverage: ignore if self.eval_ood: self.ood_dataset(self.root, download=True) + if self.eval_shift: + self.shift_dataset(self.root, download=True) def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: @@ -159,12 +164,18 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: download=False, transform=self.ood_transform, ) + if self.eval_shift: + self.shift = self.shift_dataset( + self.root, + download=False, + transform=self.test_transform, + ) def test_dataloader(self) -> list[DataLoader]: r"""Get the test dataloaders for MNIST. Return: - List[DataLoader]: Dataloaders of the MNIST test set (in + list[DataLoader]: Dataloaders of the MNIST test set (in distribution data) and FashionMNIST test split (out-of-distribution data). """ diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index bec3025d..542a10fd 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -9,8 +9,12 @@ from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN -from torch_uncertainty.datamodules.abstract import TUDataModule -from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet +from torch_uncertainty.datamodules import TUDataModule +from torch_uncertainty.datasets.classification import ( + ImageNetO, + TinyImageNet, + TinyImageNetC, +) from torch_uncertainty.utils import ( create_train_val_split, interpolation_modes_from_str, @@ -21,12 +25,16 @@ class TinyImageNetDataModule(TUDataModule): num_classes = 200 num_channels = 3 training_task = "classification" + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) def __init__( self, root: str | Path, batch_size: int, eval_ood: bool = False, + eval_shift: bool = False, + shift_severity: int = 1, val_split: float | None = None, ood_ds: str = "svhn", interpolation: str = "bilinear", @@ -44,8 +52,10 @@ def __init__( pin_memory=pin_memory, persistent_workers=persistent_workers, ) - # TODO: COMPUTE STATS self.eval_ood = eval_ood + self.eval_shift = eval_shift + self.shift_severity = shift_severity + self.ood_ds = ood_ds self.interpolation = interpolation_modes_from_str(interpolation) @@ -61,7 +71,7 @@ def __init__( raise ValueError( f"OOD dataset {ood_ds} not supported for TinyImageNet." 
) - + self.shift_dataset = TinyImageNetC if basic_augment: basic_transform = T.Compose( [ @@ -79,18 +89,18 @@ def __init__( self.train_transform = T.Compose( [ + T.ToTensor(), basic_transform, main_transform, - T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = T.Compose( [ - T.Resize(64, interpolation=self.interpolation), T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.Resize(64, interpolation=self.interpolation), + T.Normalize(mean=self.mean, std=self.std), ] ) @@ -133,6 +143,13 @@ def prepare_data(self) -> None: # coverage: ignore ), ] ) + if self.eval_shift: + self.shift_dataset( + self.root, + download=True, + transform=self.test_transform, + shift_severity=self.shift_severity, + ) def setup(self, stage: Literal["fit", "test"] | None = None) -> None: if stage == "fit" or stage is None: @@ -194,6 +211,14 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.test_transform, ) + if self.eval_shift: + self.shift = self.shift_dataset( + self.root, + download=False, + shift_severity=self.shift_severity, + transform=self.test_transform, + ) + def train_dataloader(self) -> DataLoader: r"""Get the training dataloader for TinyImageNet. @@ -214,12 +239,14 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders for TinyImageNet. Return: - List[DataLoader]: test set for in distribution data + list[DataLoader]: test set for in distribution data and out-of-distribution data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader def _get_train_data(self) -> ArrayLike: diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index df067d97..2e3f7865 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -7,7 +7,7 @@ from torchvision.datasets import VisionDataset from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index 84f99ac7..bd4b06b7 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -1,19 +1,33 @@ +import logging from pathlib import Path import torch +from torch import nn +from torch.nn.common_types import _size_2_t +from torch.nn.modules.utils import _pair from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.segmentation import CamVid +from torch_uncertainty.transforms import RandomRescale class CamVidDataModule(TUDataModule): + num_channels = 3 + training_task = "segmentation" + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + def __init__( self, root: str | Path, batch_size: int, - val_split: float | None = None, # FIXME: not used for now + crop_size: _size_2_t = 640, + eval_size: _size_2_t = (720, 960), + group_classes: bool = True, + basic_augment: bool = True, + 
val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, @@ -23,6 +37,22 @@ def __init__( Args: root (str or Path): Root directory of the datasets. batch_size (int): Number of samples per batch. + crop_size (sequence or int, optional): Desired input image and + segmentation mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``640``. + eval_size (sequence or int, optional): Desired input image and + segmentation mask sizes during evaluation. If size is an int, + smaller edge of the images will be matched to this number, i.e., + if :math:`\text{height}>\text{width}`, then the image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(720,960)``. + group_classes (bool, optional): Whether to group the 32 classes into + 11 superclasses. Defaults to ``True``. + basic_augment (bool): Whether to apply base augmentations. Defaults to + ``True``. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -42,7 +72,7 @@ v2.Compose( [ - v2.Resize((360, 480)), + v2.Resize(640), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, @@ -57,6 +87,9 @@ This behavior can be modified by overriding ``self.train_transform`` and ``self.test_transform`` after initialization. """ + if val_split is not None: # coverage: ignore + logging.warning("val_split is not used for CamVidDataModule.") + super().__init__( root=root, batch_size=batch_size, @@ -65,11 +98,37 @@ ) + if group_classes: + self.num_classes = 11 + else: + self.num_classes = 32 self.dataset = CamVid + self.group_classes = group_classes + self.crop_size = _pair(crop_size) + self.eval_size = _pair(eval_size) + + if basic_augment: + basic_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), + v2.ColorJitter( + brightness=0.5, contrast=0.5, saturation=0.5 + ), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() self.train_transform = v2.Compose( [ - v2.Resize((360, 480)), + v2.ToImage(), + basic_transform, v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, @@ -78,11 +137,13 @@ scale=True, ), + v2.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = v2.Compose( [ - v2.Resize((360, 480)), + v2.ToImage(), + v2.Resize(size=self.eval_size, antialias=True), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, @@ -91,6 +152,7 @@ scale=True, ), + v2.Normalize(mean=self.mean, std=self.std), ] ) @@ -102,12 +164,14 @@ def setup(self, stage: str | None = None) -> None: self.train = self.dataset( root=self.root, split="train", + group_classes=self.group_classes, download=False, transforms=self.train_transform, ) self.val = self.dataset( root=self.root, split="val", + group_classes=self.group_classes, download=False, transforms=self.test_transform, ) @@ -115,6 +179,7 @@ self.test = self.dataset( root=self.root, split="test", +
group_classes=self.group_classes, download=False, transforms=self.test_transform, ) diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index ad583664..75e89515 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -1,24 +1,32 @@ from pathlib import Path import torch +from torch import nn from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.segmentation import Cityscapes from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split class CityscapesDataModule(TUDataModule): + num_classes = 19 + num_channels = 3 + training_task = "segmentation" + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + def __init__( self, root: str | Path, batch_size: int, crop_size: _size_2_t = 1024, eval_size: _size_2_t = (1024, 2048), + basic_augment: bool = True, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -41,6 +49,8 @@ def __init__( :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. Defaults to ``(1024,2048)``. + basic_augment (bool): Whether to apply base augmentations. Defaults to + ``True``. val_split (float or None, optional): Share of training samples to use for validation. Defaults to ``None``. num_workers (int, optional): Number of dataloaders to use. Defaults to @@ -110,17 +120,28 @@ def __init__( self.crop_size = _pair(crop_size) self.eval_size = _pair(eval_size) + if basic_augment: + basic_transform = v2.Compose( + [ + RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + v2.RandomCrop( + size=self.crop_size, + pad_if_needed=True, + fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, + ), + v2.ColorJitter( + brightness=0.5, contrast=0.5, saturation=0.5 + ), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + self.train_transform = v2.Compose( [ v2.ToImage(), - RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), - v2.RandomCrop( - size=self.crop_size, - pad_if_needed=True, - fill={tv_tensors.Image: 0, tv_tensors.Mask: 255}, - ), - v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), - v2.RandomHorizontalFlip(), + basic_transform, v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, @@ -129,9 +150,7 @@ def __init__( }, scale=True, ), - v2.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), + v2.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = v2.Compose( @@ -146,9 +165,7 @@ def __init__( }, scale=True, ), - v2.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index 9ba10ee4..5e967000 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -6,13 +6,17 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import TUDataModule +from torch_uncertainty.datamodules import TUDataModule from 
torch_uncertainty.datasets import MUAD from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split class MUADDataModule(TUDataModule): + training_task = "segmentation" + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + def __init__( self, root: str | Path, @@ -128,9 +132,7 @@ def __init__( }, scale=True, ), - v2.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), + v2.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = v2.Compose( @@ -144,9 +146,7 @@ def __init__( }, scale=True, ), - v2.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index b10fa0b9..225410f8 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -23,7 +23,7 @@ class CIFAR10C(VisionDataset): takes in the target and transforms it. Defaults to None. subset (str): The subset to use, one of ``all`` or the keys in ``cifarc_subsets``. - severity (int): The severity of the corruption, between 1 and 5. + shift_severity (int): The shift_severity of the corruption, between 1 and 5. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Defaults to False. @@ -88,7 +88,7 @@ def __init__( transform: Callable | None = None, target_transform: Callable | None = None, subset: str = "all", - severity: int = 1, + shift_severity: int = 1, download: bool = False, ) -> None: self.root = Path(root) @@ -111,34 +111,34 @@ def __init__( f"The subset '{subset}' does not exist in CIFAR-C." ) self.subset = subset - self.severity = severity + self.shift_severity = shift_severity - if severity not in list(range(1, 6)): + if shift_severity not in list(range(1, 6)): raise ValueError( - "Corruptions severity should be chosen between 1 and 5 " + "Corruptions shift_severity should be chosen between 1 and 5 " "included." ) samples, labels = self.make_dataset( - self.root, self.subset, self.severity + self.root, self.subset, self.shift_severity ) self.samples = samples self.labels = labels.astype(np.int64) def make_dataset( - self, root: Path, subset: str, severity: int + self, root: Path, subset: str, shift_severity: int ) -> tuple[np.ndarray, np.ndarray]: r"""Make the CIFAR-C dataset. Build the corrupted dataset according to the chosen subset and - severity. If the subset is 'all', gather all corruption types + shift_severity. If the subset is 'all', gather all corruption types in the dataset. Args: root (Path):The path to the dataset. subset (str): The name of the corruption subset to be used. Choose `all` for the dataset to contain all subsets. - severity (int): The severity of the corruption applied to the + shift_severity (int): The shift_severity of the corruption applied to the images. 
Returns: @@ -146,11 +146,11 @@ def make_dataset( """ if subset == "all": labels: np.ndarray = np.load(root / "labels.npy")[ - (severity - 1) * 10000 : severity * 10000 + (shift_severity - 1) * 10000 : shift_severity * 10000 ] sample_arrays = [ np.load(root / (cifar_subset + ".npy"))[ - (severity - 1) * 10000 : severity * 10000 + (shift_severity - 1) * 10000 : shift_severity * 10000 ] for cifar_subset in self.cifarc_subsets ] @@ -159,10 +159,10 @@ def make_dataset( else: samples: np.ndarray = np.load(root / (subset + ".npy"))[ - (severity - 1) * 10000 : severity * 10000 + (shift_severity - 1) * 10000 : shift_severity * 10000 ] labels: np.ndarray = np.load(root / "labels.npy")[ - (severity - 1) * 10000 : severity * 10000 + (shift_severity - 1) * 10000 : shift_severity * 10000 ] return samples, labels @@ -207,7 +207,7 @@ def download(self) -> None: class CIFAR100C(CIFAR10C): base_folder = "CIFAR-100-C" - tar_md5 = "11f0ed0f1191edbf9fa23466ae6021d3" + tgz_md5 = "11f0ed0f1191edbf9fa23466ae6021d3" ctest_list = [ ["fog.npy", "4efc7ebd5e82b028bdbe13048e3ea564"], ["jpeg_compression.npy", "c851b7f1324e1d2ffddeb76920576d11"], @@ -230,5 +230,5 @@ class CIFAR100C(CIFAR10C): ["labels.npy", "bb4026e9ce52996b95f439544568cdb2"], ["pixelate.npy", "96c00c60f144539e14cffb02ddbd0640"], ] - cifarc_url = "https://zenodo.org/record/3555552/files/CIFAR-100-C.tar" + url = "https://zenodo.org/record/3555552/files/CIFAR-100-C.tar" filename = "CIFAR-100-C.tar" diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py index 762ff346..74be5f68 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py @@ -21,7 +21,7 @@ class TinyImageNetC(ImageFolder): takes in the target and transforms it. Defaults to None. subset (str): The subset to use, one of ``all`` or the keys in ``cifarc_subsets``. - severity (int): The severity of the corruption, between 1 and 5. + shift_severity (int): The shift_severity of the corruption, between 1 and 5. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Defaults to False. @@ -68,7 +68,7 @@ def __init__( transform: Callable | None = None, target_transform: Callable | None = None, subset: str = "all", - severity: int = 1, + shift_severity: int = 1, download: bool = False, ) -> None: self.root = Path(root) @@ -89,28 +89,28 @@ def __init__( f"The subset '{subset}' does not exist in TinyImageNet-C." ) self.subset = subset - self.severity = severity + self.shift_severity = shift_severity self.transform = transform self.target_transform = target_transform - if severity not in list(range(1, 6)): + if shift_severity not in list(range(1, 6)): raise ValueError( - "Corruptions severity should be chosen between 1 and 5 included." + "Corruptions shift_severity should be chosen between 1 and 5 included." ) - # Update samples given the subset and severity - self._make_c_dataset(self.subset, self.severity) + # Update samples given the subset and shift_severity + self._make_c_dataset(self.subset, self.shift_severity) - def _make_c_dataset(self, subset: str, severity: int) -> None: + def _make_c_dataset(self, subset: str, shift_severity: int) -> None: r"""Build the corrupted dataset according to the chosen subset and - severity. If the subset is 'all', gather all corruption types + shift_severity. 
If the subset is 'all', gather all corruption types in the dataset. Args: subset (str): The name of the corruption subset to be used. Choose `all` for the dataset to contain all subsets. - severity (int): The severity of the corruption applied to the + shift_severity (int): The shift_severity of the corruption applied to the images. """ if subset == "all": @@ -120,7 +120,7 @@ def _make_c_dataset(self, subset: str, severity: int) -> None: ( img[0] .replace("brightness", subset) - .replace("/1/", "/" + str(severity) + "/"), + .replace("/1/", "/" + str(shift_severity) + "/"), img[1], ) for img in self.imgs @@ -134,7 +134,7 @@ def _make_c_dataset(self, subset: str, severity: int) -> None: ( img[0] .replace("brightness", subset) - .replace("/1/", "/" + str(severity) + "/"), + .replace("/1/", "/" + str(shift_severity) + "/"), img[1], ) for img in self.imgs diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py new file mode 100644 index 00000000..8f021942 --- /dev/null +++ b/torch_uncertainty/datasets/corrupted.py @@ -0,0 +1,105 @@ +from copy import deepcopy +from pathlib import Path + +from PIL import Image +from torch import nn +from torchvision.datasets import VisionDataset +from torchvision.transforms import ToPILImage, ToTensor +from tqdm.auto import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm + +from torch_uncertainty.transforms.corruption import corruption_transforms + + +class CorruptedDataset(VisionDataset): + def __init__( + self, + core_dataset: VisionDataset, + shift_severity: int, + on_the_fly: bool = False, + ) -> None: + super().__init__() + self.core_dataset = core_dataset + if shift_severity <= 0: + raise ValueError( + f"Severity must be greater than 0. Got {shift_severity}." + ) + self.shift_severity = shift_severity + self.core_length = len(core_dataset) + self.on_the_fly = on_the_fly + # Keep the original transforms aside and apply them after corruption; + # the attribute names must match their use in __getitem__ below. + self.transform = deepcopy(core_dataset.transform) + self.target_transform = deepcopy(core_dataset.target_transform) + self.core_dataset.transform = None + self.core_dataset.target_transform = None + + self.root = Path(core_dataset.root) + dataset_name = type(core_dataset).__name__.lower() + self.root /= dataset_name + "_corrupted" + self.root /= f"severity_{self.shift_severity}" + self.root.mkdir(parents=True, exist_ok=True) + + if not on_the_fly: + self.to_tensor = ToTensor() + self.to_pil = ToPILImage() + self.samples = [] + self.targets = [] # filled by save_corruption, one entry per corrupted image + self.prepare_data() + + def prepare_data(self): + with logging_redirect_tqdm(): + for corruption in tqdm(corruption_transforms): + corruption_name = corruption.__name__.lower() + (self.root / corruption_name).mkdir(parents=True, exist_ok=True) + self.save_corruption( + self.root / corruption_name, corruption(self.shift_severity) + ) + + def save_corruption(self, root: Path, corruption: nn.Module) -> None: + for i in range(self.core_length): + img, tgt = self.core_dataset[i] + if isinstance(img, str | Path): + img = Image.open(img).convert("RGB") + img = corruption(self.to_tensor(img)) + self.to_pil(img).save(root / f"{i}.png") + self.samples.append(root / f"{i}.png") + self.targets.append(tgt) + + def __len__(self): + """The length of the corrupted dataset.""" + return len(self.core_dataset) * len(corruption_transforms) + + def __getitem__(self, idx: int): + """Get the corrupted image and the target. + + Args: + idx (int): Index of the image to retrieve.
+ """ + if self.on_the_fly: + # Corruption classes must be instantiated at the chosen severity. + corrupt = corruption_transforms[idx // len(self.core_dataset)]( + self.shift_severity + ) + idx = idx % len(self.core_dataset) + img, target = self.core_dataset[idx] + + img = corrupt(img) + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + return img, target + + # Otherwise the corrupted images were written to disk by prepare_data. + img = Image.open(self.samples[idx]).convert("RGB") + target = self.targets[idx] + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + +if __name__ == "__main__": + from torchvision.datasets import CIFAR10 + + dataset = CIFAR10(root="data", download=True) + corrupted_dataset = CorruptedDataset(dataset, shift_severity=1) + print(len(corrupted_dataset)) diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index f27ffe57..5d82085b 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -3,10 +3,16 @@ import os import shutil from collections.abc import Callable +from importlib import util from pathlib import Path from typing import Literal -import cv2 +if util.find_spec("cv2"): + import cv2 + + cv2_installed = True +else: # coverage: ignore + cv2_installed = False import numpy as np import torch from einops import rearrange @@ -74,6 +80,12 @@ def __init__( MUAD cannot be used for commercial purposes. Read MUAD's license carefully before using it and verify that you can comply. """ + if not cv2_installed: # coverage: ignore + raise ImportError( + "The cv2 library is not installed. Please install " + "torch_uncertainty with the image option: " + """pip install -U "torch_uncertainty[image]".""" + ) logging.info( "MUAD is restricted to non-commercial use. By using MUAD, you " "agree to the terms and conditions." ) @@ -255,3 +267,7 @@ def _download(self, split: str) -> None: download_and_extract_archive( split_url, self.root, md5=self.zip_md5[split] ) + + @property + def color_palette(self) -> list: + return self.train_id_to_color.tolist() diff --git a/torch_uncertainty/datasets/nyu.py b/torch_uncertainty/datasets/nyu.py index c67a944e..15556bfa 100644 --- a/torch_uncertainty/datasets/nyu.py +++ b/torch_uncertainty/datasets/nyu.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Literal -import cv2 import numpy as np from PIL import Image from torchvision import tv_tensors @@ -14,6 +13,13 @@ download_url, ) +if util.find_spec("cv2"): + import cv2 + + cv2_installed = True +else: # coverage: ignore + cv2_installed = False + if util.find_spec("h5py"): import h5py @@ -55,12 +61,19 @@ def __init__( max_depth (float): Maximum depth value. Defaults to 10. download (bool): Download dataset if not found. Defaults to False. """ + if not cv2_installed: # coverage: ignore + raise ImportError( + "The cv2 library is not installed. Please install " + "torch_uncertainty with the image option: " + """pip install -U "torch_uncertainty[image]".""" + ) if not h5py_installed: # coverage: ignore raise ImportError( "The h5py library is not installed. 
Please install" "torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) + super().__init__(Path(root) / "NYUv2", transforms=transforms) self.min_depth = min_depth self.max_depth = max_depth diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 5bf3c7fa..7957f688 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -26,19 +26,65 @@ class CamVidClass(NamedTuple): class CamVid(VisionDataset): # Notes: some classes are not used here classes = [ + CamVidClass("animal", 0, (64, 128, 64)), + CamVidClass("archway", 1, (192, 0, 128)), + CamVidClass("bicyclist", 2, (0, 128, 192)), + CamVidClass("bridge", 3, (0, 128, 64)), + CamVidClass("building", 4, (128, 0, 0)), + CamVidClass("car", 5, (64, 0, 128)), + CamVidClass("cart_luggage_pram", 6, (64, 0, 192)), + CamVidClass("child", 7, (192, 128, 64)), + CamVidClass("column_pole", 8, (192, 192, 128)), + CamVidClass("fence", 9, (64, 64, 128)), + CamVidClass("lane_mkgs_driv", 10, (128, 0, 192)), + CamVidClass("lane_mkgs_non_driv", 11, (192, 0, 64)), + CamVidClass("misc_text", 12, (128, 128, 64)), + CamVidClass("motorcycle_scooter", 13, (192, 0, 192)), + CamVidClass("othermoving", 14, (128, 64, 64)), + CamVidClass("parking_block", 15, (64, 192, 128)), + CamVidClass("pedestrian", 16, (64, 64, 0)), + CamVidClass("road", 17, (128, 64, 128)), + CamVidClass("road_shoulder", 18, (128, 128, 192)), + CamVidClass("sidewalk", 19, (0, 0, 192)), + CamVidClass("sign_symbol", 20, (192, 128, 128)), + CamVidClass("sky", 21, (128, 128, 128)), + CamVidClass("suv_pickup_truck", 22, (64, 128, 192)), + CamVidClass("traffic_cone", 23, (0, 0, 64)), + CamVidClass("traffic_light", 24, (0, 64, 64)), + CamVidClass("train", 25, (192, 64, 128)), + CamVidClass("tree", 26, (128, 128, 0)), + CamVidClass("truck_bus", 27, (192, 128, 192)), + CamVidClass("tunnel", 28, (64, 0, 64)), + CamVidClass("vegetation_misc", 29, (192, 192, 0)), + CamVidClass("void", 30, (0, 0, 0)), + CamVidClass("wall", 31, (64, 192, 0)), + ] + superclasses = [ CamVidClass("sky", 0, (128, 128, 128)), CamVidClass("building", 1, (128, 0, 0)), CamVidClass("pole", 2, (192, 192, 128)), - CamVidClass("road_marking", 3, (255, 69, 0)), - CamVidClass("road", 4, (128, 64, 128)), - CamVidClass("pavement", 5, (60, 40, 222)), - CamVidClass("tree", 6, (128, 128, 0)), - CamVidClass("sign_symbol", 7, (192, 128, 128)), - CamVidClass("fence", 8, (64, 64, 128)), - CamVidClass("car", 9, (64, 0, 128)), - CamVidClass("pedestrian", 10, (64, 64, 0)), - CamVidClass("bicyclist", 11, (0, 128, 192)), - CamVidClass("unlabelled", 12, (0, 0, 0)), + CamVidClass("road", 3, (128, 64, 128)), + CamVidClass("pavement", 4, (0, 0, 192)), + CamVidClass("tree", 5, (128, 128, 0)), + CamVidClass("sign_symbol", 6, (192, 128, 128)), + CamVidClass("fence", 7, (64, 64, 128)), + CamVidClass("car", 8, (64, 0, 128)), + CamVidClass("pedestrian", 9, (64, 64, 0)), + CamVidClass("bicyclist", 10, (0, 128, 192)), + CamVidClass("void", None, (0, 0, 0)), + ] + superclasses_indices = [ + [21], + [3, 4, 31, 28, 1], + [8, 23], + [17, 10, 11], + [19, 15, 18], + [26, 29], + [20, 12, 24], + [9], + [5, 22, 27, 25, 14], + [16, 7, 6, 0], + [2, 13], ] urls = { @@ -64,6 +110,7 @@ class CamVid(VisionDataset): def __init__( self, root: str, + group_classes: bool = True, split: Literal["train", "val", "test"] | None = None, transforms: Callable | None = None, download: bool = False, @@ -73,6 +120,8 @@ def __init__( 
Args: root (str): Root directory of dataset where ``camvid/`` exists or will be saved to if download is set to ``True``. + group_classes (bool, optional): Whether to group the 32 classes into + 11 superclasses. Default: ``True``. split (str, optional): The dataset split, supports ``train``, ``val`` and ``test``. Default: ``None``. transforms (callable, optional): A function/transform that takes @@ -89,6 +138,16 @@ def __init__( ) super().__init__(root, transforms, None, None) + self.group_classes = group_classes + self.class_to_superclass = [] + for i in range(32): + if i == 30: # For void + self.class_to_superclass.append(None) + + for j, superclass in enumerate(self.superclasses_indices): + if i in superclass: + self.class_to_superclass.append(j) + break if download: self.download() @@ -144,9 +203,12 @@ def encode_target(self, target: Image.Image) -> torch.Tensor: colored_target = F.pil_to_tensor(target) colored_target = rearrange(colored_target, "c h w -> h w c") target = torch.zeros_like(colored_target[..., :1]) + # convert target color to index for camvid_class in self.classes: - index = camvid_class.index if camvid_class.index != 12 else 255 + index = camvid_class.index if camvid_class.index != 30 else 255 + if self.group_classes and index != 255: + index = self.class_to_superclass[index] target[ ( colored_target @@ -166,17 +228,31 @@ def decode_target(self, target: torch.Tensor) -> Image.Image: Image.Image: Decoded target as a PIL.Image. """ colored_target = repeat(target.clone(), "h w -> h w 3", c=3) - - for camvid_class in self.classes: - colored_target[ - ( - target - == torch.tensor(camvid_class.index, dtype=target.dtype) - ).all(dim=0) - ] = torch.tensor(camvid_class.color, dtype=target.dtype) - + if not self.group_classes: + for camvid_class in self.classes: + colored_target[ + ( + target + == torch.tensor(camvid_class.index, dtype=target.dtype) + ).all(dim=0) + ] = torch.tensor(camvid_class.color, dtype=target.dtype) + else: + for camvid_class in self.superclasses: + colored_target[ + ( + target + == torch.tensor(camvid_class.index, dtype=target.dtype) + ).all(dim=0) + ] = torch.tensor(camvid_class.color, dtype=target.dtype) return F.to_pil_image(rearrange(colored_target, "h w c -> c h w")) + @property + def color_palette(self) -> list[tuple[int, int, int]]: + """Return the color palette of the dataset.""" + if self.group_classes: + return [camvid_class.color for camvid_class in self.superclasses] + return [camvid_class.color for camvid_class in self.classes] + def __getitem__( self, index: int ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: diff --git a/torch_uncertainty/datasets/segmentation/cityscapes.py b/torch_uncertainty/datasets/segmentation/cityscapes.py index a30aa813..769174df 100644 --- a/torch_uncertainty/datasets/segmentation/cityscapes.py +++ b/torch_uncertainty/datasets/segmentation/cityscapes.py @@ -16,7 +16,7 @@ def __init__( root: str, split: str = "train", mode: str = "fine", - target_type: torch.List[str] | str = "instance", + target_type: list[str] | str = "instance", transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, transforms: Callable[..., Any] | None = None, @@ -120,3 +120,8 @@ def plot_sample( The axis on which the sample was plotted. 
""" raise NotImplementedError("This method is not implemented yet.") + + @property + def color_palette(self) -> list[tuple[int, int, int]]: + """Return the color palette of the dataset.""" + return [c.color for c in self.classes] diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 094b4834..3858300e 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -1,12 +1,15 @@ +import math from typing import Any +import torch from einops import rearrange from torch import Tensor, nn +from torch.nn import functional as F from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t def check_packed_parameters_consistency( - alpha: int, gamma: int, num_estimators: int + alpha: float, gamma: int, num_estimators: int ) -> None: """Check the consistency of the parameters of the Packed-Ensembles layers. @@ -49,35 +52,38 @@ def __init__( self, in_features: int, out_features: int, - alpha: int, + alpha: float, num_estimators: int, gamma: int = 1, bias: bool = True, - rearrange: bool = True, first: bool = False, last: bool = False, + implementation: str = "legacy", + rearrange: bool = True, device=None, dtype=None, ) -> None: r"""Packed-Ensembles-style Linear layer. This layer computes fully-connected operation for a given number of - estimators (:attr:`num_estimators`) using a `1x1` convolution. + estimators (:attr:`num_estimators`). Args: in_features (int): Number of input features of the linear layer. out_features (int): Number of channels produced by the linear layer. - alpha (int): The width multiplier of the linear layer. + alpha (float): The width multiplier of the linear layer. num_estimators (int): The number of estimators grouped in the layer. gamma (int, optional): Defaults to ``1``. bias (bool, optional): It ``True``, adds a learnable bias to the output. Defaults to ``True``. - rearrange (bool, optional): Rearrange the input and outputs for - compatibility with previous and later layers. Defaults to ``True``. first (bool, optional): Whether this is the first layer of the network. Defaults to ``False``. last (bool, optional): Whether this is the last layer of the network. Defaults to ``False``. + implementation (str, optional): The implementation to use. Defaults + to ``"legacy"``. + rearrange (bool, optional): Rearrange the input and outputs for + compatibility with previous and later layers. Defaults to ``True``. device (torch.device, optional): The device to use for the layer's parameters. Defaults to ``None``. 
dtype (torch.dtype, optional): The dtype to use for the layer's @@ -110,6 +116,7 @@ def __init__( self.first = first self.num_estimators = num_estimators self.rearrange = rearrange + self.implementation = implementation # Define the number of features of the underlying convolution extended_in_features = int(in_features * (1 if first else alpha)) @@ -130,42 +137,93 @@ def __init__( actual_groups ) - self.conv1x1 = nn.Conv1d( - in_channels=extended_in_features, - out_channels=extended_out_features, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - groups=actual_groups, - bias=bias, - padding_mode="zeros", - **factory_kwargs, - ) + # FIXME: This is a temporary check + assert implementation in [ + "legacy", + "sparse", + "full", + "einsum", + ], f"Unknown implementation: {implementation} for PackedLinear" + + if self.implementation == "legacy": + self.weight = nn.Parameter( + torch.empty( + ( + extended_out_features, + extended_in_features // actual_groups, + 1, + ), + **factory_kwargs, + ) + ) + else: + self.weight = nn.Parameter( + torch.empty( + ( + actual_groups, + extended_out_features // actual_groups, + extended_in_features // actual_groups, + ), + **factory_kwargs, + ) + ) + + self.in_features = extended_in_features // actual_groups + self.out_features = extended_out_features // actual_groups + self.groups = actual_groups + + if bias: + self.bias = nn.Parameter( + torch.empty(extended_out_features, **factory_kwargs) + ) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.implementation == "legacy": + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + else: + for n in range(self.groups): + nn.init.kaiming_uniform_(self.weight[n], a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0]) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + if self.implementation == "sparse": + self.weight = nn.Parameter( + torch.block_diag(*self.weight).to_sparse() + ) def _rearrange_forward(self, x: Tensor) -> Tensor: x = x.unsqueeze(-1) if not self.first: x = rearrange(x, "(m e) c h -> e (m c) h", m=self.num_estimators) - - x = self.conv1x1(x) + x = F.conv1d(x, self.weight, self.bias, 1, 0, 1, self.groups) x = rearrange(x, "e (m c) h -> (m e) c h", m=self.num_estimators) return x.squeeze(-1) def forward(self, inputs: Tensor) -> Tensor: - if self.rearrange: - return self._rearrange_forward(inputs) - return self.conv1x1(inputs) - - @property - def weight(self) -> Tensor: - r"""The weight of the underlying convolutional layer.""" - return self.conv1x1.weight - - @property - def bias(self) -> Tensor | None: - r"""The bias of the underlying convolutional layer.""" - return self.conv1x1.bias + if self.implementation == "legacy": + if self.rearrange: + return self._rearrange_forward(inputs) + return F.conv1d( + inputs, self.weight, self.bias, 1, 0, 1, self.groups + ) + if self.implementation == "full": + block_diag = torch.block_diag(*self.weight) + return F.linear(inputs, block_diag, self.bias) + if self.implementation == "sparse": + out = inputs @ self.weight.transpose(0, 1) + return out if self.bias is None else out + self.bias + if self.implementation == "einsum": + out = torch.einsum( + "bki,kij->bkj", + inputs.view(-1, self.groups, self.in_features), + self.weight.transpose(1, 2), + ).flatten(start_dim=-2, end_dim=-1) + return out if self.bias is None else out + self.bias + raise ValueError(f"Unknown implementation: {self.implementation}") class PackedConv1d(nn.Module): diff --git
a/torch_uncertainty/losses/__init__.py b/torch_uncertainty/losses/__init__.py index 318295e1..2257e52f 100644 --- a/torch_uncertainty/losses/__init__.py +++ b/torch_uncertainty/losses/__init__.py @@ -1,4 +1,9 @@ # ruff: noqa: F401 from .bayesian import ELBOLoss, KLDiv -from .classification import ConfidencePenaltyLoss, ConflictualLoss, DECLoss +from .classification import ( + ConfidencePenaltyLoss, + ConflictualLoss, + DECLoss, + FocalLoss, +) from .regression import BetaNLL, DERLoss, DistributionNLLLoss diff --git a/torch_uncertainty/losses/classification.py b/torch_uncertainty/losses/classification.py index 0b9230b9..2d74d852 100644 --- a/torch_uncertainty/losses/classification.py +++ b/torch_uncertainty/losses/classification.py @@ -153,8 +153,8 @@ def forward( annealing_coef = self.reg_weight else: annealing_coef = torch.min( - torch.tensor(1.0, dtype=evidence.dtype), - torch.tensor( + input=torch.tensor(1.0, dtype=evidence.dtype), + other=torch.tensor( current_epoch / self.annealing_step, dtype=evidence.dtype ), ) @@ -195,6 +195,7 @@ def __init__( if reduction not in ("none", "mean", "sum"): raise ValueError(f"{reduction} is not a valid value for reduction.") self.reduction = reduction + if eps < 0: raise ValueError( "The epsilon value should be non-negative, but got " f"{eps}." ) @@ -279,3 +280,59 @@ def forward(self, logits: Tensor, targets: Tensor) -> Tensor: if self.reduction == "mean": return ce_loss + self.reg_weight * reg_loss.mean() return ce_loss + self.reg_weight * reg_loss + + +class FocalLoss(nn.Module): + def __init__( + self, + gamma: float, + alpha: Tensor | None = None, + reduction: str = "mean", + ) -> None: + """Focal loss for classification tasks. + + Args: + gamma (float): Focusing parameter that down-weights the loss of + well-classified examples, as described in the paper. + alpha (Tensor, optional): Weights for each class. Defaults to None. + reduction (str, optional): 'mean', 'sum' or 'none'. + Defaults to 'mean'. + + Reference: + Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). + Focal Loss for Dense Object Detection. TPAMI 2020. + + Implementation: + Inspired by github.com/AdeelH/pytorch-multi-class-focal-loss. + """ + if reduction not in ("none", "mean", "sum") and reduction is not None: + raise ValueError(f"{reduction} is not a valid value for reduction.") + self.reduction = reduction + + if gamma < 0: + raise ValueError( + "The gamma term of the focal loss should be non-negative, but got " + f"{gamma}."
+ ) + self.gamma = gamma + + super().__init__() + self.alpha = alpha + self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none") + + def forward(self, x: Tensor, y: Tensor) -> Tensor: + log_p = F.log_softmax(x, dim=-1) + ce = self.nll_loss(log_p, y) + + all_rows = torch.arange(len(x)) + log_pt = log_p[all_rows, y] + + pt = log_pt.exp() + focal_term = (1 - pt) ** self.gamma + + loss = focal_term * ce + + if self.reduction == "mean": + return loss.mean() + if self.reduction == "sum": + return loss.sum() + return loss diff --git a/torch_uncertainty/metrics/classification/calibration_error.py b/torch_uncertainty/metrics/classification/calibration_error.py index 512323b3..e361d499 100644 --- a/torch_uncertainty/metrics/classification/calibration_error.py +++ b/torch_uncertainty/metrics/classification/calibration_error.py @@ -1,73 +1,218 @@ from typing import Any, Literal import matplotlib.pyplot as plt +import matplotlib.ticker as mticker +import numpy as np +import seaborn as sns import torch from torchmetrics.classification.calibration_error import ( BinaryCalibrationError, MulticlassCalibrationError, ) +from torchmetrics.functional.classification.calibration_error import ( + _binning_bucketize, +) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel -from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE +from torchmetrics.utilities.plot import _PLOT_OUT_TYPE from .adaptive_calibration_error import AdaptiveCalibrationError -def _ce_plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: - fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (None, ax) +def _reliability_diagram_subplot( + ax, + accuracies: np.ndarray, + confidences: np.ndarray, + bin_sizes: np.ndarray, + bins: np.ndarray, + title: str = "Reliability Diagram", + xlabel: str = "Top-class Confidence (%)", + ylabel: str = "Success Rate (%)", +) -> None: + widths = 1.0 / len(bin_sizes) + positions = bins + widths / 2.0 + alphas = 0.2 + 0.8 * bin_sizes - conf = dim_zero_cat(self.confidences) - acc = dim_zero_cat(self.accuracies) - bin_width = 1 / self.n_bins + colors = np.zeros((len(bin_sizes), 4)) + colors[:, 0] = 240 / 255.0 + colors[:, 1] = 60 / 255.0 + colors[:, 2] = 60 / 255.0 + colors[:, 3] = alphas - bin_ids = torch.round( - torch.clamp(conf * self.n_bins, 1e-5, self.n_bins - 1 - 1e-5) + gap_plt = ax.bar( + positions, + np.abs(accuracies - confidences), + bottom=np.minimum(accuracies, confidences), + width=widths, + edgecolor=colors, + color=colors, + linewidth=1, + label="Gap", ) - val, inverse, counts = bin_ids.unique( - return_inverse=True, return_counts=True + + acc_plt = ax.bar( + positions, + 0, + bottom=accuracies, + width=widths, + edgecolor="black", + color="black", + alpha=1.0, + linewidth=3, + label="Accuracy", ) - counts = counts.float() - val_oh = torch.nn.functional.one_hot( - val.long(), num_classes=self.n_bins - ).float() - - # add 1e-6 to avoid division NaNs - values = ( - val_oh.T - @ torch.sum( - acc.unsqueeze(1) * torch.nn.functional.one_hot(inverse).float(), - 0, - ) - / (val_oh.T @ counts + 1e-6) + + ax.set_aspect("equal") + ax.plot([0, 1], [0, 1], linestyle="--", color="gray") + + gaps = np.abs(accuracies - confidences) + ece = (np.sum(gaps * bin_sizes) / np.sum(bin_sizes)) * 100 + + ax.text( + 0.98, + 0.02, + f"ECE={ece:.03}%", + color="black", + ha="right", + va="bottom", + transform=ax.transAxes, ) - plt.rc("axes", axisbelow=True) - ax.hist( - x=[bin_width * i * 100 for i 
in range(self.n_bins)], - weights=values.cpu() * 100, - bins=[bin_width * i * 100 for i in range(self.n_bins + 1)], - alpha=0.7, - linewidth=1, - edgecolor="#0d559f", - color="#1f77b4", + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + + ax.set_title(title) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + ax.legend(handles=[gap_plt, acc_plt]) + + +def _confidence_histogram_subplot( + ax, + accuracies: np.ndarray, + confidences: np.ndarray, + title="Examples per bin", + xlabel="Top-class Confidence (%)", + ylabel="Density", +) -> None: + sns.kdeplot( + confidences, + linewidth=2, + ax=ax, + fill=True, + alpha=0.5, + ) + + ax.set_xlim(0, 1) + ax.set_title(title) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + avg_acc = np.mean(accuracies) + avg_conf = np.mean(confidences) + + acc_plt = ax.axvline( + x=avg_acc, + ls="solid", + lw=3, + c="black", + label="Accuracy", + ) + conf_plt = ax.axvline( + x=avg_conf, + ls="dotted", + lw=3, + c="#444", + label="Avg. confidence", + ) + ax.legend(handles=[acc_plt, conf_plt], loc="upper left") + + +def reliability_chart( + accuracies: np.ndarray, + confidences: np.ndarray, + bin_accuracies: np.ndarray, + bin_confidences: np.ndarray, + bin_sizes: np.ndarray, + bins: np.ndarray, + title="Reliability Diagram", + figsize=(6, 6), + dpi=72, +) -> _PLOT_OUT_TYPE: + """Builds Reliability Diagram + `Source `_. + """ + figsize = (figsize[0], figsize[0] * 1.4) + + fig, ax = plt.subplots( + nrows=2, + ncols=1, + sharex=True, + figsize=figsize, + dpi=dpi, + gridspec_kw={"height_ratios": [4, 1]}, ) - ax.plot([0, 100], [0, 100], "--", color="#0d559f") - plt.grid(True, linestyle="--", alpha=0.7, zorder=0) - ax.set_xlabel("Top-class Confidence (%)", fontsize=16) - ax.set_ylabel("Success Rate (%)", fontsize=16) - ax.set_xlim(0, 100) - ax.set_ylim(0, 100) - ax.set_aspect("equal", "box") - if fig is not None: - fig.tight_layout() + plt.tight_layout() + plt.subplots_adjust(hspace=0) + + # reliability diagram subplot + _reliability_diagram_subplot( + ax[0], + bin_accuracies, + bin_confidences, + bin_sizes, + bins, + title=title, + ) + + # confidence histogram subplot + _confidence_histogram_subplot(ax[1], accuracies, confidences, title="") + + new_ticks = np.abs(ax[1].get_yticks()).astype(np.int32) + ax[1].yaxis.set_major_locator(mticker.FixedLocator(new_ticks)) + ax[1].set_yticklabels(new_ticks) + return fig, ax +def custom_plot(self) -> _PLOT_OUT_TYPE: + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + + bin_boundaries = torch.linspace( + 0, + 1, + self.n_bins + 1, + dtype=torch.float, + device=confidences.device, + ) + + with torch.no_grad(): + acc_bin, conf_bin, prop_bin = _binning_bucketize( + confidences, accuracies, bin_boundaries + ) + + np_acc_bin = acc_bin.cpu().numpy() + np_conf_bin = conf_bin.cpu().numpy() + np_prop_bin = prop_bin.cpu().numpy() + np_bin_boundaries = bin_boundaries.cpu().numpy() + + return reliability_chart( + accuracies=accuracies.cpu().numpy(), + confidences=confidences.cpu().numpy(), + bin_accuracies=np_acc_bin, + bin_confidences=np_conf_bin, + bin_sizes=np_prop_bin, + bins=np_bin_boundaries, + ) + + # overwrite the plot method of the original metrics -BinaryCalibrationError.plot = _ce_plot -MulticlassCalibrationError.plot = _ce_plot +BinaryCalibrationError.plot = custom_plot +MulticlassCalibrationError.plot = custom_plot class CalibrationError: diff --git a/torch_uncertainty/metrics/classification/grouping_loss.py b/torch_uncertainty/metrics/classification/grouping_loss.py index 155e111f..bed2bd24 100644 
--- a/torch_uncertainty/metrics/classification/grouping_loss.py +++ b/torch_uncertainty/metrics/classification/grouping_loss.py @@ -1,5 +1,14 @@ +from importlib import util + import torch -from glest import GLEstimator as GLEstimatorBase + +if util.find_spec("glest"): + from glest import GLEstimator as GLEstimatorBase + + glest_installed = True +else: # coverage: ignore + glest_installed = False + from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities import rank_zero_warn @@ -55,6 +64,13 @@ def __init__( networks. In ICLR 2023. """ super().__init__(**kwargs) + if not glest_installed: # coverage: ignore + raise ImportError( + "The glest library is not installed. Please install " + "torch_uncertainty with the all option: " + """pip install -U "torch_uncertainty[all]".""" + ) + self.estimator = GLEstimator(None) self.add_state("probs", default=[], dist_reduce_fx="cat") @@ -81,7 +97,7 @@ def update(self, probs: Tensor, target: Tensor, features: Tensor) -> None: (batch, num_estimators, num_features) or (batch, num_features) """ if target.ndim == 2: - target = target.argmax(axis=-1) + target = target.argmax(dim=-1) elif target.ndim != 1: raise ValueError( "Expected `target` to be of shape (batch) or (batch, num_classes) " diff --git a/torch_uncertainty/models/depth/bts.py b/torch_uncertainty/models/depth/bts.py index 3284b69a..bead997f 100644 --- a/torch_uncertainty/models/depth/bts.py +++ b/torch_uncertainty/models/depth/bts.py @@ -532,7 +532,7 @@ def forward(self, features: list[Tensor]) -> Tensor | Distribution: """Forward pass. Args: - features (list[Tensor]): List of the features from the backbone. + features (list[Tensor]): list of the features from the backbone. Note: Depending on the :attr:`dist_layer` of the backbone, the output can diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 720ce7f0..0dd17547 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -28,7 +28,7 @@ def __init__( Args: in_features (int): Number of input features. num_outputs (int): Number of output features. - hidden_dims (List[int]): Number of features for each hidden layer. + hidden_dims (list[int]): Number of features for each hidden layer. layer (nn.Module): Layer class. activation (Callable): Activation function. layer_args (Dict): Arguments for the layer class. @@ -131,7 +131,7 @@ def mlp( Args: in_features (int): Number of input features. num_outputs (int): Number of output features. - hidden_dims (List[int]): Number of features in each hidden layer. + hidden_dims (list[int]): Number of features in each hidden layer. layer (nn.Module, optional): Layer type. Defaults to nn.Linear. activation (Callable, optional): Activation function. Defaults to F.relu. diff --git a/torch_uncertainty/models/utils.py b/torch_uncertainty/models/utils.py index ecf06466..7722b36e 100644 --- a/torch_uncertainty/models/utils.py +++ b/torch_uncertainty/models/utils.py @@ -11,7 +11,7 @@ def __init__(self, model: nn.Module, feat_names: list[str]) -> None: Args: model (nn.Module): Base model. - feat_names (list[str]): List of the feature names. + feat_names (list[str]): list of the feature names. """ super().__init__() self.model = model @@ -24,7 +24,7 @@ def forward(self, x: Tensor) -> list[Tensor]: x (Tensor): Input tensor. Returns: - list[Tensor]: List of the features. + list[Tensor]: list of the features.
""" feature = x features = [] diff --git a/torch_uncertainty/models/wideresnet/packed.py b/torch_uncertainty/models/wideresnet/packed.py index ae9970d5..25195925 100644 --- a/torch_uncertainty/models/wideresnet/packed.py +++ b/torch_uncertainty/models/wideresnet/packed.py @@ -87,7 +87,7 @@ def __init__( num_classes: int, conv_bias: bool, dropout_rate: float, - num_estimators: int = 4, + num_estimators: int, alpha: int = 2, gamma: int = 1, groups: int = 1, diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 92d299e8..23589f8c 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -55,7 +55,9 @@ def __init__( super().__init__() if not laplace_installed: # coverage: ignore raise ImportError( - "The laplace-torch library is not installed. Please install it via `pip install laplace-torch`." + "The laplace-torch library is not installed. Please install" + "torch_uncertainty with the all option:" + """pip install -U "torch_uncertainty[all]".""" ) self.pred_type = pred_type diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index c2ae660c..e76fc65a 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -70,6 +70,7 @@ def __init__( optim_recipe: dict | Optimizer | None = None, mixup_params: dict | None = None, eval_ood: bool = False, + eval_shift: bool = False, eval_grouping_loss: bool = False, ood_criterion: Literal[ "msp", "logit", "energy", "entropy", "mi", "vr" @@ -94,10 +95,12 @@ def __init__( optionally the scheduler to use. Defaults to ``None``. mixup_params (dict, optional): Mixup parameters. Can include mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, - mixup alpha, and cutmix alpha. If None, no augmentations. + mixup alpha, and cutmix alpha. If None, no mixup augmentations. Defaults to ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD - detection performance or not. Defaults to ``False``. + detection performance. Defaults to ``False``. + eval_shift (bool, optional): Indicates whether to evaluate the Distribution + shift performance. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. ood_criterion (str, optional): OOD criterion. 
Available options are @@ -145,6 +148,7 @@ def __init__( self.num_classes = num_classes self.eval_ood = eval_ood + self.eval_shift = eval_shift self.eval_grouping_loss = eval_grouping_loss self.ood_criterion = ood_criterion self.log_plots = log_plots @@ -228,6 +232,9 @@ def _init_metrics(self) -> None: self.test_ood_metrics = ood_metrics.clone(prefix="ood/") self.test_ood_entropy = Entropy() + if self.eval_shift: + self.test_shift_metrics = cls_metrics.clone(prefix="shift/") + # metrics for ensembles only if self.is_ensemble: ens_metrics = MetricCollection( @@ -243,6 +250,11 @@ if self.eval_ood: self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") + if self.eval_shift: + self.test_shift_ens_metrics = ens_metrics.clone( + prefix="shift/ens_" + ) + if self.eval_grouping_loss: grouping_loss = MetricCollection( {"cls/grouping_loss": GroupingLoss()} ) @@ -488,7 +500,7 @@ def test_step( pp_probs = pp_logits self.post_cls_metrics.update(pp_probs, targets) - elif self.eval_ood and dataloader_idx == 1: + if self.eval_ood and dataloader_idx == 1: self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_ood_entropy(probs) self.log( @@ -503,15 +515,20 @@ if self.ood_logit_storage is not None: self.ood_logit_storage.append(logits.detach().cpu()) + if self.eval_shift and dataloader_idx == (2 if self.eval_ood else 1): + self.test_shift_metrics.update(probs, targets) + if self.is_ensemble: + self.test_shift_ens_metrics.update(probs_per_est) + def on_validation_epoch_end(self) -> None: - self.log_dict( - self.val_cls_metrics.compute(), logger=True, sync_dist=True - ) + res_dict = self.val_cls_metrics.compute() + self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "Acc%", - self.val_cls_metrics["cls/Acc"].compute() * 100, + res_dict["val/cls/Acc"] * 100, prog_bar=True, logger=False, + sync_dist=True, ) self.val_cls_metrics.reset() @@ -557,6 +574,20 @@ def on_test_epoch_end(self) -> None: self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) + if self.eval_shift: + tmp_metrics = self.test_shift_metrics.compute() + shift_severity = self.trainer.test_dataloaders[ + 2 if self.eval_ood else 1 + ].dataset.shift_severity + tmp_metrics["shift/shift_severity"] = shift_severity + self.log_dict(tmp_metrics, sync_dist=True) + result_dict.update(tmp_metrics) + + if self.is_ensemble: + tmp_metrics = self.test_shift_ens_metrics.compute() + self.log_dict(tmp_metrics, sync_dist=True) + result_dict.update(tmp_metrics) + if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( "Reliability diagram", self.test_cls_metrics["cal/ECE"].plot()[0] diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index bd5cd6eb..a6749bf0 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -55,6 +55,7 @@ def __init__( is_ensemble: bool = False, format_batch_fn: nn.Module | None = None, optim_recipe: dict | Optimizer | None = None, + eval_shift: bool = False, num_image_plot: int = 4, log_plots: bool = False, ) -> None: @@ -70,6 +71,8 @@ Defaults to ``False``. optim_recipe (dict or Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. + eval_shift (bool, optional): Indicates whether to evaluate the distribution + shift performance. Defaults to ``False``. format_batch_fn (nn.Module, optional): The function to format the batch.
Defaults to ``None``. num_image_plot (int, optional): Number of images to plot. Defaults to ``4``. @@ -78,6 +81,11 @@ """ super().__init__() _depth_routine_checks(output_dim, num_image_plot, log_plots) + if eval_shift: + raise NotImplementedError( + "Distribution shift evaluation not implemented yet. Raise an issue " + "if needed." + ) self.model = model self.output_dim = output_dim @@ -287,7 +295,15 @@ self.test_prob_metrics.update(mixture, targets, padding_mask) def on_validation_epoch_end(self) -> None: - self.log_dict(self.val_metrics.compute(), sync_dist=True) + res_dict = self.val_metrics.compute() + self.log_dict(res_dict, logger=True, sync_dist=True) + self.log( + "RMSE", + res_dict["val/reg/RMSE"], + prog_bar=True, + logger=False, + sync_dist=True, + ) self.val_metrics.reset() if self.probabilistic: self.log_dict( diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 12538db5..a90f08b2 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -35,6 +35,7 @@ def __init__( loss: nn.Module, is_ensemble: bool = False, optim_recipe: dict | Optimizer | None = None, + eval_shift: bool = False, format_batch_fn: nn.Module | None = None, ) -> None: r"""Routine for training & testing on **regression** tasks. @@ -49,6 +50,8 @@ Defaults to ``False``. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. + eval_shift (bool, optional): Indicates whether to evaluate the distribution + shift performance. Defaults to ``False``. format_batch_fn (torch.nn.Module, optional): The function to format the batch. Defaults to ``None``. @@ -67,6 +70,11 @@ """ super().__init__() _regression_routine_checks(output_dim) + if eval_shift: + raise NotImplementedError( + "Distribution shift evaluation not implemented yet. Raise an issue " + "if needed."
+ ) self.model = model self.probabilistic = probabilistic @@ -234,7 +242,15 @@ def test_step( self.test_prob_metrics.update(mixture, targets) def on_validation_epoch_end(self) -> None: - self.log_dict(self.val_metrics.compute(), sync_dist=True) + res_dict = self.val_metrics.compute() + self.log_dict(res_dict, logger=True, sync_dist=True) + self.log( + "RMSE", + res_dict["val/reg/RMSE"], + prog_bar=True, + logger=False, + sync_dist=True, + ) self.val_metrics.reset() if self.probabilistic: self.log_dict(self.val_prob_metrics.compute(), sync_dist=True) diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 03fa2f58..185b004a 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -6,7 +6,9 @@ from torch import Tensor, nn from torch.optim import Optimizer from torchmetrics import Accuracy, MetricCollection +from torchvision.transforms.v2 import ToDtype from torchvision.transforms.v2 import functional as F +from torchvision.utils import draw_segmentation_masks from torch_uncertainty.metrics import ( AUGRC, @@ -20,6 +22,7 @@ EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, ) +from torch_uncertainty.utils.plotting import show class SegmentationRoutine(LightningModule): @@ -29,9 +32,11 @@ def __init__( num_classes: int, loss: nn.Module, optim_recipe: dict | Optimizer | None = None, + eval_shift: bool = False, format_batch_fn: nn.Module | None = None, metric_subsampling_rate: float = 1e-2, log_plots: bool = False, + num_samples_to_plot: int = 3, num_calibration_bins: int = 15, ) -> None: r"""Routine for training & testing on **segmentation** tasks. @@ -42,12 +47,16 @@ loss (torch.nn.Module): Loss function to optimize the :attr:`model`. optim_recipe (dict or Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. + eval_shift (bool, optional): Indicates whether to evaluate the distribution + shift performance. Defaults to ``False``. format_batch_fn (torch.nn.Module, optional): The function to format the batch. Defaults to ``None``. metric_subsampling_rate (float, optional): The rate of subsampling for the memory consuming metrics. Defaults to ``1e-2``. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. + num_samples_to_plot (int, optional): Number of samples to plot in the + segmentation results. Defaults to ``3``. num_calibration_bins (int, optional): Number of bins to compute calibration metrics. Defaults to ``15``. @@ -65,6 +74,11 @@ metric_subsampling_rate, num_calibration_bins, ) + if eval_shift: + raise NotImplementedError( + "Distribution shift evaluation not implemented yet. Raise an issue " + "if needed."
+ ) self.model = model self.num_classes = num_classes @@ -126,6 +140,10 @@ self.test_seg_metrics = seg_metrics.clone(prefix="test/") self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test/") + if log_plots: + self.num_samples_to_plot = num_samples_to_plot + self.sample_buffer = [] + def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -202,11 +220,27 @@ def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST, ) + logits = rearrange( - logits, "(m b) c h w -> (b h w) m c", b=targets.size(0) + logits, "(m b) c h w -> b m c h w", b=targets.size(0) ) - probs_per_est = logits.softmax(dim=-1) + probs_per_est = logits.softmax(dim=2) probs = probs_per_est.mean(dim=1) + + if ( + self.log_plots + and len(self.sample_buffer) < self.num_samples_to_plot + ): + max_count = self.num_samples_to_plot - len(self.sample_buffer) + for i, (_img, _prb, _tgt) in enumerate( + zip(img, probs, targets, strict=False) + ): + if i >= max_count: + break + _pred = _prb.argmax(dim=0, keepdim=True) + self.sample_buffer.append((_img, _pred, _tgt)) + + probs = rearrange(probs, "b c h w -> (b h w) c") targets = targets.flatten() valid_mask = targets != 255 probs, targets = probs[valid_mask], targets[valid_mask] @@ -214,13 +248,13 @@ self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) def on_validation_epoch_end(self) -> None: - self.log_dict( - self.val_seg_metrics.compute(), logger=True, sync_dist=True - ) + res_dict = self.val_seg_metrics.compute() + self.log_dict(res_dict, logger=True, sync_dist=True) self.log( "mIoU%", - self.val_seg_metrics["seg/mIoU"].compute() * 100, + res_dict["val/seg/mIoU"] * 100, prog_bar=True, + sync_dist=True, ) self.log_dict(self.val_sbsmpl_seg_metrics.compute(), sync_dist=True) self.val_seg_metrics.reset() @@ -231,17 +265,58 @@ def on_test_epoch_end(self) -> None: self.log_dict(self.test_sbsmpl_seg_metrics.compute(), sync_dist=True) if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( - "Reliabity diagram", + "Calibration/Reliability diagram", self.test_sbsmpl_seg_metrics["cal/ECE"].plot()[0], ) self.logger.experiment.add_figure( - "Risk-Coverage curve", + "Selective Classification/Risk-Coverage curve", self.test_sbsmpl_seg_metrics["sc/AURC"].plot()[0], ) self.logger.experiment.add_figure( - "Generalized Risk-Coverage curve", + "Selective Classification/Generalized Risk-Coverage curve", self.test_sbsmpl_seg_metrics["sc/AUGRC"].plot()[0], ) + self.log_segmentation_plots() + + def log_segmentation_plots(self) -> None: + """Builds and logs examples of segmentation plots from the test set.""" + for i, (img, pred, tgt) in enumerate(self.sample_buffer): + pred = ( + pred + == torch.arange(self.num_classes, device=pred.device)[ + :, None, None + ] + ) + tgt = ( + tgt + == torch.arange(self.num_classes, device=tgt.device)[ + :, None, None + ] + ) + + # Undo normalization on the image and convert to uint8.
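+ # This assumes the datamodule exposes the `mean` and `std` used by its + # Normalize transform, so `img * std + mean` inverts `(raw - mean) / std`.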
+ mean = torch.tensor(self.trainer.datamodule.mean, device=img.device) + std = torch.tensor(self.trainer.datamodule.std, device=img.device) + img = img * std[:, None, None] + mean[:, None, None] + img = ToDtype(torch.uint8, scale=True)(img) + + dataset = self.trainer.datamodule.test + if hasattr(dataset, "color_palette"): + color_palette = dataset.color_palette + else: + color_palette = None + + pred_mask = draw_segmentation_masks( + img, pred, alpha=0.7, colors=color_palette + ) + gt_mask = draw_segmentation_masks( + img, tgt, alpha=0.7, colors=color_palette + ) + + self.logger.experiment.add_figure( + f"Segmentation results/{i}", + show(pred_mask, gt_mask), + ) def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: total_size = target.size(0) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 72bf6700..70a59267 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -1,5 +1,6 @@ """Adapted from https://github.com/hendrycks/robustness.""" +import ctypes from importlib import util from io import BytesIO @@ -22,6 +23,14 @@ else: # coverage: ignore skimage_installed = False +if util.find_spec("scipy"): + from scipy.ndimage import map_coordinates + from scipy.ndimage import zoom as scizoom + + scipy_installed = True +else: # coverage: ignore + scipy_installed = False + from torch import Tensor, nn from torchvision.transforms import ( InterpolationMode, @@ -31,157 +40,167 @@ ToTensor, ) +if util.find_spec("wand"): + from wand.api import library as wandlibrary + from wand.image import Image as WandImage + + wandlibrary.MagickMotionBlurImage.argtypes = ( + ctypes.c_void_p, # wand + ctypes.c_double, # radius + ctypes.c_double, # sigma + ctypes.c_double, + ) # angle + + class MotionImage(WandImage): + def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0): + wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle) + + wand_installed = True +else: # coverage: ignore + wand_installed = False + from torch_uncertainty.datasets import FrostImages +from .image import Brightness as IBrightness +from .image import Contrast as IContrast +from .image import Saturation as ISaturation + __all__ = [ - "DefocusBlur", - "Frost", - "GaussianBlur", "GaussianNoise", - "GlassBlur", + "ShotNoise", "ImpulseNoise", - "JPEGCompression", + "DefocusBlur", + "GlassBlur", + "MotionBlur", + "ZoomBlur", + "Snow", + "Frost", + "Fog", + "Brightness", + "Contrast", + "Elastic", "Pixelate", - "ShotNoise", + "JPEGCompression", + "GaussianBlur", "SpeckleNoise", + "Saturation", + "corruption_transforms", ] -class GaussianNoise(nn.Module): +class TUCorruption(nn.Module): def __init__(self, severity: int) -> None: + """Base class for corruptions.""" super().__init__() if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") if not isinstance(severity, int): raise TypeError("Severity must be an integer.") self.severity = severity + + def __repr__(self) -> str: + """Printable representation.""" + return self.__class__.__name__ + f"(severity={self.severity})" + + +class GaussianNoise(TUCorruption): + def __init__(self, severity: int) -> None: + """Add Gaussian noise to an image. + + Args: + severity (int): Severity level of the corruption. 
+ """ + super().__init__(severity) self.scale = [0, 0.04, 0.06, 0.08, 0.09, 0.10][severity] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - return torch.clip(torch.normal(img, self.scale), 0, 1) - - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" + return torch.clamp(torch.normal(img, self.scale), 0, 1) -class ShotNoise(nn.Module): +class ShotNoise(TUCorruption): def __init__(self, severity: int) -> None: - super().__init__() - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity + """Add shot noise to an image. + + Args: + severity (int): Severity level of the corruption. + """ + super().__init__(severity) self.scale = [500, 250, 100, 75, 50][severity - 1] def forward(self, img: Tensor): if self.severity == 0: return img - return torch.clip(torch.poisson(img * self.scale) / self.scale, 0, 1) - - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" + return torch.clamp(torch.poisson(img * self.scale) / self.scale, 0, 1) -class ImpulseNoise(nn.Module): +class ImpulseNoise(TUCorruption): def __init__(self, severity: int) -> None: - super().__init__() + """Add impulse noise to an image. + + Args: + severity (int): Severity level of the corruption. + """ + super().__init__(severity) if not skimage_installed: # coverage: ignore raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity self.scale = [0, 0.01, 0.02, 0.03, 0.05, 0.07][severity] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - return torch.clip( + return torch.clamp( torch.as_tensor(random_noise(img, mode="s&p", amount=self.scale)), - 0, - 1, + torch.zeros(1), + torch.ones(1), ) - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" - -class SpeckleNoise(nn.Module): +class DefocusBlur(TUCorruption): def __init__(self, severity: int) -> None: - super().__init__() - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity - self.scale = [0.06, 0.1, 0.12, 0.16, 0.2][severity - 1] - - def forward(self, img: Tensor) -> Tensor: - if self.severity == 0: - return img - return torch.clip( - img + img * torch.normal(img, self.scale), - 0, - 1, - ) + """Add defocus blur to an image. - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" - - -class GaussianBlur(nn.Module): - def __init__(self, severity: int) -> None: - super().__init__() - if not skimage_installed: # coverage: ignore + Args: + severity (int): Severity level of the corruption. 
+ """ + super().__init__(severity) + if not cv2_installed: # coverage: ignore raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity - self.sigma = [0.4, 0.6, 0.7, 0.8, 1.0][severity - 1] + self.radius = [3, 4, 6, 8, 10][severity - 1] + self.alias_blur = [0.1, 0.5, 0.5, 0.5, 0.5][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - return torch.clip( - torch.as_tensor(gaussian(img, sigma=self.sigma)), - 0, - 1, - ) - - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" + img = img.numpy() + channels = [ + torch.as_tensor( + cv2.filter2D( + img[ch, :, :], + -1, + disk(self.radius, alias_blur=self.alias_blur), + ) + ) + for ch in range(3) + ] + return torch.clamp(torch.stack(channels), 0, 1) -class GlassBlur(nn.Module): # TODO: batch +class GlassBlur(TUCorruption): # TODO: batch def __init__(self, severity: int) -> None: - super().__init__() + super().__init__(severity) if not skimage_installed or not cv2_installed: # coverage: ignore raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity self.sigma = [0.05, 0.25, 0.4, 0.25, 0.4][severity - 1] self.max_delta = 1 self.iterations = [1, 1, 1, 2, 2][severity - 1] @@ -202,14 +221,10 @@ def forward(self, img: Tensor) -> Tensor: img[h_prime, w_prime], img[h, w], ) - return torch.clip( + return torch.clamp( torch.as_tensor(gaussian(img, sigma=self.sigma)), 0, 1 ) - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" - def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): if radius <= 8: @@ -225,73 +240,279 @@ def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) -class DefocusBlur(nn.Module): +class MotionBlur(TUCorruption): def __init__(self, severity: int) -> None: - super().__init__() - if not cv2_installed: # coverage: ignore + super().__init__(severity) + self.rng = np.random.default_rng() + self.radius = [10, 15, 15, 15, 20][severity - 1] + self.sigma = [3, 5, 8, 12, 15][severity - 1] + self.to_pil_img = ToPILImage() + self.to_tensor = ToTensor() + + if not wand_installed: # coverage: ignore raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity - self.radius = [0.3, 0.4, 0.5, 1, 1.5][severity - 1] - self.alias_blur = [0.4, 0.5, 0.6, 0.2, 0.1][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - img = np.array(img) - channels = [ - torch.as_tensor( - cv2.filter2D( - img[d, :, :], - -1, - disk(self.radius, alias_blur=self.alias_blur), - ) + output = BytesIO() + pil_img = self.to_pil_img(img) + pil_img.save(output, "PNG") + x = 
MotionImage(blob=output.getvalue()) + x.motion_blur( + radius=self.radius, + sigma=self.sigma, + angle=self.rng.uniform(-45, 45), + ) + x = cv2.imdecode( + np.frombuffer(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED + ) + x = np.clip(x[..., [2, 1, 0]], 0, 255) + return self.to_tensor(x) + + +def clipped_zoom(img, zoom_factor): + h = img.shape[0] + # ceil crop height(= crop width) + ch = int(np.ceil(h / zoom_factor)) + + top = (h - ch) // 2 + img = scizoom( + img[top : top + ch, top : top + ch], + (zoom_factor, zoom_factor, 1), + order=1, + ) + # trim off any extra pixels + trim_top = (img.shape[0] - h) // 2 + + return img[trim_top : trim_top + h, trim_top : trim_top + h] + + +class ZoomBlur(TUCorruption): + def __init__(self, severity: int) -> None: + super().__init__(severity) + self.zooms = [ + np.arange(1, 1.11, 0.01), + np.arange(1, 1.16, 0.01), + np.arange(1, 1.21, 0.02), + np.arange(1, 1.26, 0.02), + np.arange(1, 1.31, 0.03), + ][severity - 1] + + if not scipy_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the all option: " + """pip install -U "torch_uncertainty[all]".""" ) - for d in range(3) - ] - return torch.clip(torch.stack(channels), 0, 1) - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + img = img.permute(1, 2, 0).numpy() + out = np.zeros_like(img) + for zoom_factor in self.zooms: + out += clipped_zoom(img, zoom_factor) + img = (img + out) / (len(self.zooms) + 1) + return torch.clamp(torch.as_tensor(img).permute(2, 0, 1), 0, 1) -class JPEGCompression(nn.Module): # TODO: batch +class Snow(TUCorruption): def __init__(self, severity: int) -> None: - super().__init__() - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity - self.quality = [80, 65, 58, 50, 40][severity - 1] + super().__init__(severity) + self.mix = [ + (0.1, 0.3, 3, 0.5, 10, 4, 0.8), + (0.2, 0.3, 2, 0.5, 12, 4, 0.7), + (0.55, 0.3, 4, 0.9, 12, 8, 0.7), + (0.55, 0.3, 4.5, 0.85, 12, 8, 0.65), + (0.55, 0.3, 2.5, 0.85, 12, 12, 0.55), + ][severity - 1] + self.rng = np.random.default_rng() + + if not wand_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the image option: " + """pip install -U "torch_uncertainty[image]".""" + ) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img + _, height, width = img.shape + x = img.numpy() + snow_layer = self.rng.normal( + size=x.shape[1:], loc=self.mix[0], scale=self.mix[1] + )[..., np.newaxis] + snow_layer = clipped_zoom(snow_layer, self.mix[2]) + snow_layer[snow_layer < self.mix[3]] = 0 + snow_layer = Image.fromarray( + (np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), + mode="L", + ) output = BytesIO() - ToPILImage()(img).save(output, "JPEG", quality=self.quality) - return ToTensor()(Image.open(output)) + snow_layer.save(output, format="PNG") + snow_layer = MotionImage(blob=output.getvalue()) + snow_layer.motion_blur( + radius=self.mix[4], + sigma=self.mix[5], + angle=self.rng.uniform(-135, -45), + ) + snow_layer = ( + cv2.imdecode( + np.frombuffer(snow_layer.make_blob(), np.uint8), + cv2.IMREAD_UNCHANGED, + ) + / 255.0 + ) + snow_layer = snow_layer[np.newaxis, ...]
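+ # The blend below darkens the image toward a brightened gray-scale copy of + # itself (np.maximum keeps the brighter pixel) to mimic the flat lighting of + # snowy scenes; self.mix[6] sets the balance before the motion-blurred snow + # layer is added twice, the second time flipped with np.rot90(..., k=2).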
+ x = self.mix[6] * x + (1 - self.mix[6]) * np.maximum( + x, + cv2.cvtColor(x.transpose([1, 2, 0]), cv2.COLOR_RGB2GRAY).reshape( + 1, height, width + ) + * 1.5 + + 0.5, + ) + return torch.clamp( + torch.as_tensor(x + snow_layer + np.rot90(snow_layer, k=2)), 0, 1 + ) -class Frost(nn.Module): +class Frost(TUCorruption): + def __init__(self, severity: int) -> None: + super().__init__(severity) + self.rng = np.random.default_rng() + self.mix = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][ severity - 1 ] + self.frost_ds = FrostImages( "./data", download=True, transform=ToTensor() ) def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + _, height, width = img.shape + frost_img = RandomResizedCrop((height, width))( + self.frost_ds[self.rng.integers(low=0, high=4)] + ) + return torch.clamp(self.mix[0] * img + self.mix[1] * frost_img, 0, 1) + + +def plasma_fractal(height, width, wibbledecay=3): + """Generate a heightmap using the diamond-square algorithm. + Return a 2d array of shape (height, width) with floats in the range 0-1. + 'height' and 'width' must be equal and a power of two. + """ + maparray = np.empty((height, width), dtype=np.float64) + maparray[0, 0] = 0 + stepsize = height + wibble = 100 + rng = np.random.default_rng() + + def wibbledmean(array): + return array / 4 + wibble * rng.uniform(-wibble, wibble, array.shape) + + def fillsquares(): + """For each square of points stepsize apart, calculate middle value as mean of points + wibble.""" + cornerref = maparray[0:height:stepsize, 0:height:stepsize] + squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0) + squareaccum += np.roll(squareaccum, shift=-1, axis=1) + maparray[ + stepsize // 2 : height : stepsize, + stepsize // 2 : height : stepsize, + ] = wibbledmean(squareaccum) + + def filldiamonds(): + """For each diamond of points stepsize apart, calculate middle value as mean of points + wibble.""" + mapsize = maparray.shape[0] + drgrid = maparray[ + stepsize // 2 : mapsize : stepsize, + stepsize // 2 : mapsize : stepsize, + ] + ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize] + ldrsum = drgrid + np.roll(drgrid, 1, axis=0) + lulsum = ulgrid + np.roll(ulgrid, -1, axis=1) + ltsum = ldrsum + lulsum + maparray[0:mapsize:stepsize, stepsize // 2 : mapsize : stepsize] = ( + wibbledmean(ltsum) + ) + tdrsum = drgrid + np.roll(drgrid, 1, axis=1) + tulsum = ulgrid + np.roll(ulgrid, -1, axis=0) + ttsum = tdrsum + tulsum + maparray[stepsize // 2 : mapsize : stepsize, 0:mapsize:stepsize] = ( + wibbledmean(ttsum) + ) + + while stepsize >= 2: + fillsquares() + filldiamonds() + stepsize //= 2 + wibble /= wibbledecay + + maparray -= maparray.min() + return maparray / maparray.max() + + +class Fog(TUCorruption): + def __init__(self, severity: int, size: int = 256) -> None: + super().__init__(severity) + if (size & (size - 1) == 0) and size != 0: + self.size = size + self.resize = Resize((size, size), InterpolationMode.BICUBIC) + else: + raise ValueError(f"Size must be a power of 2. Got {size}.") + self.mix = [(1.5, 2), (2, 2), (2.5, 1.7), (2.5, 1.5), (3, 1.4)][ severity - 1 ] + + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + _, height, width = img.shape + if height != width: + raise ValueError(f"Image must be square.
Got {height}x{width}.") + img = self.resize(img) + max_val = img.max() + fog = ( + self.mix[0] + * plasma_fractal( + height=height, width=width, wibbledecay=self.mix[1] + )[:height, :width] + ) + final = torch.clamp( + (img + fog) * max_val / (max_val + self.mix[0]), 0, 1 + ) + return Resize((height, width), InterpolationMode.BICUBIC)(final) + + +class Brightness(IBrightness, TUCorruption): def __init__(self, severity: int) -> None: - super().__init__() - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity + TUCorruption.__init__(self, severity) + self.level = [1.1, 1.2, 1.3, 1.4, 1.5][severity - 1] + + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + return IBrightness.forward(self, img, self.level) + + +class Contrast(IContrast, TUCorruption): + def __init__(self, severity: int) -> None: + TUCorruption.__init__(self, severity) + self.level = [0.4, 0.3, 0.2, 0.1, 0.05][severity - 1] + + def forward(self, img: Tensor) -> Tensor | Image.Image: + if self.severity == 0: + return img + return IContrast.forward(self, img, self.level) + + +class Pixelate(TUCorruption): + def __init__(self, severity: int) -> None: + super().__init__(severity) self.quality = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1] def forward(self, img: Tensor) -> Tensor: @@ -305,33 +526,174 @@ def forward(self, img: Tensor) -> Tensor: )(img) return ToTensor()(Resize((height, width), InterpolationMode.BOX)(img)) - def __repr__(self) -> str: - """Printable representation.""" - return self.__class__.__name__ + f"(severity={self.severity})" +class JPEGCompression(TUCorruption): + def __init__(self, severity: int) -> None: + super().__init__(severity) + self.quality = [80, 65, 58, 50, 40][severity - 1] -class Frost(nn.Module): + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + output = BytesIO() + ToPILImage()(img).save(output, "JPEG", quality=self.quality) + return ToTensor()(Image.open(output)) + + +class Elastic(TUCorruption): def __init__(self, severity: int) -> None: - super().__init__() + super().__init__(severity) + if not cv2_installed or not scipy_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the all option: " + """pip install -U "torch_uncertainty[all]".""" + ) + # The following perturbation values are based on the original repo but + # are quite strange, notably for the severities 3 and 4 + self.mix = [ + (2, 0.7, 0.1), + (2, 0.08, 0.2), + (0.05, 0.01, 0.02), + (0.07, 0.01, 0.02), + (0.12, 0.01, 0.02), + ][severity - 1] self.rng = np.random.default_rng() - if not (0 <= severity <= 5): - raise ValueError("Severity must be between 0 and 5.") - if not isinstance(severity, int): - raise TypeError("Severity must be an integer.") - self.severity = severity - self.mix = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][ - severity - 1 - ] - self.frost_ds = FrostImages( - "./data", download=True, transform=ToTensor() + + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + image = np.array(img.permute(1, 2, 0), dtype=np.float32) + shape = image.shape + shape_size = shape[:2] + + # random affine + center_square = np.float32(shape_size) // 2 + square_size = min(shape_size) // 3 + pts1 = np.float32( + [ + center_square + square_size, + [ + center_square[0] + square_size, + center_square[1] - square_size, + ], + center_square - square_size,
] + ) + pts2 = pts1 + self.rng.uniform( + -self.mix[2] * shape_size[0], + self.mix[2] * shape_size[0], + size=pts1.shape, + ).astype(np.float32) + affine_transform = cv2.getAffineTransform(pts1, pts2) + image = cv2.warpAffine( + image, + affine_transform, + shape_size[::-1], + borderMode=cv2.BORDER_REFLECT_101, ) + dx = ( + gaussian( + self.rng.uniform(-1, 1, size=shape[:2]), + self.mix[1] * shape_size[0], + mode="reflect", + truncate=3, + ) + * self.mix[0] + * shape_size[0] + ).astype(np.float32) + dy = ( + gaussian( + self.rng.uniform(-1, 1, size=shape[:2]), + self.mix[1] * shape_size[0], + mode="reflect", + truncate=3, + ) + * self.mix[0] + * shape_size[0] + ).astype(np.float32) + dx, dy = dx[..., np.newaxis], dy[..., np.newaxis] + + x, y, z = np.meshgrid( + np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]) + ) + indices = ( + np.reshape(y + dy, (-1, 1)), + np.reshape(x + dx, (-1, 1)), + np.reshape(z, (-1, 1)), + ) + img = np.clip( + map_coordinates(image, indices, order=1, mode="reflect").reshape( + shape + ), + 0, + 1, + ) + return torch.as_tensor(img).permute(2, 0, 1) + + +class SpeckleNoise(TUCorruption): + def __init__(self, severity: int) -> None: + super().__init__(severity) + self.scale = [0.06, 0.1, 0.12, 0.16, 0.2][severity - 1] + self.rng = np.random.default_rng() + def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - _, height, width = img.shape - frost_img = RandomResizedCrop((height, width))( - self.frost_ds[self.rng.integers(low=0, high=4)] + return torch.clamp( + img + img * self.rng.normal(img, self.scale), + 0, + 1, + ) + + +class GaussianBlur(TUCorruption): + def __init__(self, severity: int) -> None: + super().__init__(severity) + if not skimage_installed: # coverage: ignore + raise ImportError( + "Please install torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" + ) + self.sigma = [0.4, 0.6, 0.7, 0.8, 1.0][severity - 1] + + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + return torch.clamp( + torch.as_tensor(gaussian(img, sigma=self.sigma)), + min=0, + max=1, ) - return torch.clip(self.mix[0] * img + self.mix[1] * frost_img, 0, 1) + +class Saturation(ISaturation, TUCorruption): + def __init__(self, severity: int) -> None: + TUCorruption.__init__(self, severity) + self.severity = severity + self.level = [0.1, 0.2, 0.3, 0.4, 0.5][severity - 1] + + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + return ISaturation.forward(self, img, self.level) + + +corruption_transforms = ( + GaussianNoise, + ShotNoise, + ImpulseNoise, + DefocusBlur, + GlassBlur, + MotionBlur, + ZoomBlur, + Snow, + Frost, + Fog, + Brightness, + Contrast, + Elastic, + Pixelate, + JPEGCompression, +) diff --git a/torch_uncertainty/transforms/cutout.py b/torch_uncertainty/transforms/cutout.py index 9a91215f..3af0e477 100644 --- a/torch_uncertainty/transforms/cutout.py +++ b/torch_uncertainty/transforms/cutout.py @@ -1,4 +1,3 @@ -import numpy as np import torch from torch import nn @@ -12,31 +11,29 @@ def __init__(self, length: int, value: int = 0) -> None: value (int): Pixel value to be filled in the cutout square. """ super().__init__() - self.rng = np.random.default_rng() if length <= 0: - raise ValueError("Cutout length must be positive.") + raise ValueError(f"Cutout length must be positive. 
Got {length}.") self.length = length if value < 0 or value > 255: - raise ValueError("Cutout value must be between 0 and 255.") + raise ValueError( + f"Cutout value must be between 0 and 255. Got {value}." + ) self.value = value def __call__(self, img: torch.Tensor) -> torch.Tensor: if len(img.shape) == 2: img = img.unsqueeze(0) h, w = img.size(1), img.size(2) - mask = np.ones((h, w), np.float32) - y = self.rng.integers(low=0, high=h - 1) - x = self.rng.integers(low=0, high=w - 1) + mask = torch.ones(size=(h, w), dtype=torch.float32) + y = torch.randint(high=h, size=(1,)) + x = torch.randint(high=w, size=(1,)) - y1 = np.clip(y - self.length // 2, 0, h) - y2 = np.clip(y + self.length // 2, 0, h) - x1 = np.clip(x - self.length // 2, 0, w) - x2 = np.clip(x + self.length // 2, 0, w) + y1 = torch.clip(y - self.length // 2, 0, h) + y2 = torch.clip(y + self.length // 2, 0, h) + x1 = torch.clip(x - self.length // 2, 0, w) + x2 = torch.clip(x + self.length // 2, 0, w) mask[y1:y2, x1:x2] = self.value - mask = torch.from_numpy(mask) - mask = mask.expand_as(img) - img *= mask - return img + return img * mask.expand_as(img) diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py index 2e8707eb..ba1eb838 100644 --- a/torch_uncertainty/transforms/image.py +++ b/torch_uncertainty/transforms/image.py @@ -68,8 +68,8 @@ def __init__( random_direction: bool = True, interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST, expand: bool = False, - center: list[int] | None = None, - fill: list[int] | None = None, + center: list[float] | None = None, + fill: list[float] | None = None, ) -> None: super().__init__() self.random_direction = random_direction @@ -105,8 +105,8 @@ def __init__( axis: int, random_direction: bool = True, interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST, - center: list[int] | None = None, - fill: list[int] | None = None, + center: list[float] | None = None, + fill: list[float] | None = None, ) -> None: super().__init__() if axis not in (0, 1): @@ -124,7 +124,7 @@ def forward( self.random_direction and torch.rand(1).item() > 0.5 ): # coverage: ignore level = -level - shear = [0, 0] + shear = [0.0, 0.0] shear[self.axis] = level return F.affine( img, @@ -148,8 +148,8 @@ def __init__( axis: int, random_direction: bool = True, interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST, - center: list[int] | None = None, - fill: list[int] | None = None, + center: list[float] | None = None, + fill: list[float] | None = None, ) -> None: super().__init__() if axis not in (0, 1): @@ -161,13 +161,13 @@ def __init__( self.fill = fill def forward( - self, img: Tensor | Image.Image, level: int + self, img: Tensor | Image.Image, level: float ) -> Tensor | Image.Image: if ( self.random_direction and torch.rand(1).item() > 0.5 ): # coverage: ignore level = -level - translate = [0, 0] + translate = [0.0, 0.0] translate[self.axis] = level return F.affine( img, @@ -205,14 +205,25 @@ class Brightness(nn.Module): def __init__(self) -> None: super().__init__() - def forward( - self, img: Tensor | Image.Image, level: float - ) -> Tensor | Image.Image: + def forward(self, img: Tensor, level: float) -> Tensor: if level < 0: raise ValueError("Level must be greater than 0.") return F.adjust_brightness(img, level) +class Saturation(nn.Module): + level_type = float + corruption_overlap = True + + def __init__(self) -> None: + super().__init__() + + def forward(self, img: Tensor, level: float) -> Tensor: + if level < 0: + raise ValueError("Level must be 
diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py
index 2e8707eb..ba1eb838 100644
--- a/torch_uncertainty/transforms/image.py
+++ b/torch_uncertainty/transforms/image.py
@@ -68,8 +68,8 @@ def __init__(
         random_direction: bool = True,
         interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST,
         expand: bool = False,
-        center: list[int] | None = None,
-        fill: list[int] | None = None,
+        center: list[float] | None = None,
+        fill: list[float] | None = None,
     ) -> None:
         super().__init__()
         self.random_direction = random_direction
@@ -105,8 +105,8 @@ def __init__(
         axis: int,
         random_direction: bool = True,
         interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST,
-        center: list[int] | None = None,
-        fill: list[int] | None = None,
+        center: list[float] | None = None,
+        fill: list[float] | None = None,
     ) -> None:
         super().__init__()
         if axis not in (0, 1):
@@ -124,7 +124,7 @@ def forward(
             self.random_direction and torch.rand(1).item() > 0.5
         ):  # coverage: ignore
             level = -level
-        shear = [0, 0]
+        shear = [0.0, 0.0]
         shear[self.axis] = level
         return F.affine(
             img,
@@ -148,8 +148,8 @@ def __init__(
         axis: int,
         random_direction: bool = True,
         interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST,
-        center: list[int] | None = None,
-        fill: list[int] | None = None,
+        center: list[float] | None = None,
+        fill: list[float] | None = None,
     ) -> None:
         super().__init__()
         if axis not in (0, 1):
@@ -161,13 +161,13 @@ def __init__(
         self.fill = fill

     def forward(
-        self, img: Tensor | Image.Image, level: int
+        self, img: Tensor | Image.Image, level: float
     ) -> Tensor | Image.Image:
         if (
             self.random_direction and torch.rand(1).item() > 0.5
         ):  # coverage: ignore
             level = -level
-        translate = [0, 0]
+        translate = [0.0, 0.0]
         translate[self.axis] = level
         return F.affine(
             img,
@@ -205,14 +205,25 @@ class Brightness(nn.Module):
     def __init__(self) -> None:
         super().__init__()

-    def forward(
-        self, img: Tensor | Image.Image, level: float
-    ) -> Tensor | Image.Image:
+    def forward(self, img: Tensor, level: float) -> Tensor:
         if level < 0:
             raise ValueError("Level must be greater than 0.")
         return F.adjust_brightness(img, level)


+class Saturation(nn.Module):
+    level_type = float
+    corruption_overlap = True
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, img: Tensor, level: float) -> Tensor:
+        if level < 0:
+            raise ValueError("Level must be greater than 0.")
+        return F.adjust_saturation_image(img, level)
+
+
 class Sharpen(nn.Module):
     pixmix_max_level = 1.8
     level_type = float
@@ -243,9 +254,11 @@ def forward(
     ) -> Tensor | Image.Image:
         if level < 0:
             raise ValueError("Level must be greater than 0.")
+        pil_img = F.to_pil_image(img) if isinstance(img, Tensor) else img
+        pil_img = ImageEnhance.Color(pil_img).enhance(level)
         if isinstance(img, Tensor):
-            img: Image.Image = F.to_pil_image(img)
-        return ImageEnhance.Color(img).enhance(level)
+            return F.pil_to_tensor(pil_img)
+        return pil_img


 class RandomRescale(Transform):
diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py
index bd26534d..43c4283d 100644
--- a/torch_uncertainty/transforms/mixup.py
+++ b/torch_uncertainty/transforms/mixup.py
@@ -1,9 +1,17 @@
+from importlib import util
+
 import numpy as np
-import scipy
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

+if util.find_spec("scipy"):
+    import scipy
+
+    scipy_installed = True
+else:  # coverage: ignore
+    scipy_installed = False
+

 def beta_warping(x, alpha_cdf: float = 1.0, eps: float = 1e-12) -> float:
     return scipy.stats.beta.cdf(x, a=alpha_cdf + eps, b=alpha_cdf + eps)
@@ -197,7 +205,13 @@ def __init__(
         "Tailoring Mixup to Data using Kernel Warping functions" (2023)
         https://arxiv.org/abs/2311.01434.
         """
-        super().__init__(alpha, mode, num_classes)
+        if not scipy_installed:  # coverage: ignore
+            raise ImportError(
+                "The scipy library is not installed. Please install "
+                "torch_uncertainty with the all option: "
+                """pip install -U "torch_uncertainty[all]"."""
+            )
+        super().__init__(alpha=alpha, mode=mode, num_classes=num_classes)
         self.apply_kernel = apply_kernel
         self.tau_max = tau_max
         self.tau_std = tau_std
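For reference, a self-contained sketch of the optional-scipy pattern introduced in mixup.py above; `beta_warping` mirrors the function shown in the hunk:

```python
from importlib import util

# Probe for scipy without importing it at module load time; the import
# error is deferred until the scipy-dependent feature is actually used.
if util.find_spec("scipy"):
    import scipy.stats

    scipy_installed = True
else:
    scipy_installed = False


def beta_warping(x, alpha_cdf: float = 1.0, eps: float = 1e-12) -> float:
    """Warp x through the CDF of a symmetric Beta distribution."""
    if not scipy_installed:
        raise ImportError(
            'scipy is missing. Install it with: pip install -U "torch_uncertainty[all]".'
        )
    return scipy.stats.beta.cdf(x, a=alpha_cdf + eps, b=alpha_cdf + eps)
```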
diff --git a/torch_uncertainty/utils/cli.py b/torch_uncertainty/utils/cli.py
index 8b8659c4..12a0f9b4 100644
--- a/torch_uncertainty/utils/cli.py
+++ b/torch_uncertainty/utils/cli.py
@@ -130,3 +130,8 @@ def add_default_arguments_to_parser(
             default=self.eval_after_fit_default,
         )
         super().add_default_arguments_to_parser(parser)
+
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        super().add_arguments_to_parser(parser)
+        parser.link_arguments("data.eval_ood", "model.eval_ood")
+        parser.link_arguments("data.eval_shift", "model.eval_shift")
diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py
index 3872ee6d..5bc729c7 100644
--- a/torch_uncertainty/utils/evaluation_loop.py
+++ b/torch_uncertainty/utils/evaluation_loop.py
@@ -15,6 +15,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
         # test/cls: Classification Metrics
         # test/cal: Calibration Metrics
         # ood: OOD Detection Metrics
+        # shift: Distribution Shift Metrics
         # test/sc: Selective Classification Metrics
         # test/post: Post-Processing Metrics
         # test/seg: Segmentation Metrics
@@ -29,6 +30,9 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
             "Risk@80Cov",
             "pixAcc",
             "mIoU",
+            "AURC",
+            "AUGRC",
+            "mAcc",
         ]

         metrics = {}
@@ -49,6 +53,11 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
                 metrics["ood"] = {}
                 metric_name = key.split("/")[-1]
                 metrics["ood"].update({metric_name: value})
+            elif key.startswith("shift"):
+                if "shift" not in metrics:
+                    metrics["shift"] = {}
+                metric_name = key.split("/")[-1]
+                metrics["shift"].update({metric_name: value})
             elif key.startswith("test/sc"):
                 if "sc" not in metrics:
                     metrics["sc"] = {}
@@ -196,6 +205,29 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
                     table.add_row(metric, f"{value.item():.5f}")
             tables.append(table)

+        if "shift" in metrics:
+            table = Table()
+            table.add_column(
+                first_col_name, justify="center", style="cyan", width=12
+            )
+            shift_severity = int(metrics["shift"]["shift_severity"])
+            table.add_column(
+                f"Distribution Shift lvl{shift_severity}",
+                justify="center",
+                style="magenta",
+                width=25,
+            )
+            shift_metrics = OrderedDict(sorted(metrics["shift"].items()))
+            for metric, value in shift_metrics.items():
+                if metric == "shift_severity":
+                    continue
+                if metric in percentage_metrics:
+                    value = value * 100
+                    table.add_row(metric, f"{value.item():.2f}%")
+                else:
+                    table.add_row(metric, f"{value.item():.5f}")
+            tables.append(table)
+
         console = get_console()
         group = Group(*tables)
         console.print(group)
diff --git a/torch_uncertainty/utils/plotting.py b/torch_uncertainty/utils/plotting.py
new file mode 100644
index 00000000..0e5fa482
--- /dev/null
+++ b/torch_uncertainty/utils/plotting.py
@@ -0,0 +1,19 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torchvision.transforms.functional as F
+from torch import Tensor
+
+
+def show(prediction: Tensor, target: Tensor):
+    imgs = [prediction, target]
+    fig, axs = plt.subplots(ncols=len(imgs), figsize=(12, 6))
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[i].imshow(np.asarray(img))
+        axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+    axs[0].set(title="Prediction")
+    axs[1].set(title="Ground Truth")
+
+    return fig
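Finally, a usage sketch for the new plotting helper; it assumes CHW tensors that `to_pil_image` accepts, e.g. floats in [0, 1]:

```python
import torch

from torch_uncertainty.utils.plotting import show

prediction = torch.rand(3, 128, 128)  # placeholder model output in [0, 1]
target = torch.rand(3, 128, 128)  # placeholder ground truth
fig = show(prediction, target)  # side-by-side "Prediction" / "Ground Truth"
fig.savefig("prediction_vs_target.png")
```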