From 2b9628523b636bf4f86f9ae34f8e9d03a7f7dec8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 25 Oct 2023 19:59:29 +0200 Subject: [PATCH] :shirt: Various small improvements --- torch_uncertainty/baselines/classification/wideresnet.py | 1 + torch_uncertainty/datasets/classification/cifar/cifar_c.py | 2 +- torch_uncertainty/datasets/classification/mnist_c.py | 2 +- torch_uncertainty/datasets/classification/not_mnist.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index e0a8e814..fac829c3 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -96,6 +96,7 @@ class WideResNet: LightningModule: Wide-ResNet baseline ready for training and evaluation. """ + single = ["vanilla"] ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] versions = { diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index 137e5380..ac1848fd 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -111,7 +111,7 @@ def __init__( transform=transform, target_transform=target_transform, ) - if not (subset in ["all"] + self.cifarc_subsets): + if subset not in ["all"] + self.cifarc_subsets: raise ValueError( f"The subset '{subset}' does not exist in CIFAR-C." ) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index d5119db7..fea414ef 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -93,7 +93,7 @@ def __init__( transform=transform, target_transform=target_transform, ) - if not (subset in ["all"] + self.mnistc_subsets): + if subset not in ["all"] + self.mnistc_subsets: raise ValueError( f"The subset '{subset}' does not exist in MNIST-C." ) diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index 37cb179d..2a21b802 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -49,7 +49,7 @@ def __init__( if isinstance(root, str): self.root = Path(root) - if not (subset in self.subsets): + if subset not in self.subsets: raise ValueError( f"The subset '{subset}' does not exist for notMNIST." )