From a01c46ccf6a918e2026b05b98500ce4a392762da Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 25 Oct 2023 19:07:31 +0200 Subject: [PATCH] :white_check_mark: Improve dropout cvg. --- tests/baselines/test_mc_dropout.py | 33 ++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/baselines/test_mc_dropout.py b/tests/baselines/test_mc_dropout.py index a793b5ff..5edae90e 100644 --- a/tests/baselines/test_mc_dropout.py +++ b/tests/baselines/test_mc_dropout.py @@ -28,9 +28,9 @@ def test_standard(self): ) summary(net) - _ = net.criterion - _ = net.configure_optimizers() - _ = net(torch.rand(1, 3, 32, 32)) + net.criterion + net.configure_optimizers() + net(torch.rand(1, 3, 32, 32)) class TestStandardWideBaseline: @@ -49,9 +49,9 @@ def test_standard(self): ) summary(net) - _ = net.criterion - _ = net.configure_optimizers() - _ = net(torch.rand(1, 3, 32, 32)) + net.criterion + net.configure_optimizers() + net(torch.rand(1, 3, 32, 32)) class TestStandardVGGBaseline: @@ -67,9 +67,24 @@ def test_standard(self): num_estimators=4, arch=11, groups=1, + last_layer_dropout=True, ) summary(net) - _ = net.criterion - _ = net.configure_optimizers() - _ = net(torch.rand(1, 3, 32, 32)) + net.criterion + net.configure_optimizers() + net(torch.rand(1, 3, 32, 32)) + + net = VGG( + num_classes=10, + in_channels=3, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + version="mc-dropout", + num_estimators=4, + arch=11, + groups=1, + last_layer_dropout=True, + ) + net.eval() + net(torch.rand(1, 3, 32, 32))