Skip to content

Commit

Permalink
✅ Improve dropout cvg.
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Oct 25, 2023
1 parent aa3c47b commit a01c46c
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions tests/baselines/test_mc_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def test_standard(self):
)
summary(net)

_ = net.criterion
_ = net.configure_optimizers()
_ = net(torch.rand(1, 3, 32, 32))
net.criterion
net.configure_optimizers()
net(torch.rand(1, 3, 32, 32))


class TestStandardWideBaseline:
Expand All @@ -49,9 +49,9 @@ def test_standard(self):
)
summary(net)

_ = net.criterion
_ = net.configure_optimizers()
_ = net(torch.rand(1, 3, 32, 32))
net.criterion
net.configure_optimizers()
net(torch.rand(1, 3, 32, 32))


class TestStandardVGGBaseline:
Expand All @@ -67,9 +67,24 @@ def test_standard(self):
num_estimators=4,
arch=11,
groups=1,
last_layer_dropout=True,
)
summary(net)

_ = net.criterion
_ = net.configure_optimizers()
_ = net(torch.rand(1, 3, 32, 32))
net.criterion
net.configure_optimizers()
net(torch.rand(1, 3, 32, 32))

net = VGG(
num_classes=10,
in_channels=3,
loss=nn.CrossEntropyLoss,
optimization_procedure=optim_cifar10_resnet18,
version="mc-dropout",
num_estimators=4,
arch=11,
groups=1,
last_layer_dropout=True,
)
net.eval()
net(torch.rand(1, 3, 32, 32))

0 comments on commit a01c46c

Please sign in to comment.