✨ Add LeNet BatchEnsemble and Deep Ensemble
Anton committed Mar 7, 2025
1 parent 918de33 commit c5b62d8
Showing 3 changed files with 173 additions and 0 deletions.
67 changes: 67 additions & 0 deletions experiments/classification/mnist/configs/lenet_batch_ensemble.yaml
@@ -0,0 +1,67 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
  fast_dev_run: false
  accelerator: gpu
  devices: 1
  precision: 16-mixed
  max_epochs: 10
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: logs/lenet_trajectory
      name: batch_ensemble
      default_hp_metric: false
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val/cls/Acc
        mode: max
        save_last: true
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: val/cls/Acc
        patience: 1000
        check_finite: true
model:
  # ClassificationRoutine
  model:
    # BatchEnsemble
    class_path: torch_uncertainty.models.lenet.batchensemble_lenet
    init_args:
      in_channels: 1
      num_classes: 10
      num_estimators: 5
      activation: torch.nn.ReLU
      norm: torch.nn.BatchNorm2d
      groups: 1
      dropout_rate: 0
  num_classes: 10
  loss: CrossEntropyLoss
  is_ensemble: true
  format_batch_fn:
    class_path: torch_uncertainty.transforms.batch.RepeatTarget
    init_args:
      num_repeats: 5
data:
  root: ./data
  batch_size: 128
  num_workers: 127
  eval_ood: true
  eval_shift: true
optimizer:
  lr: 0.05
  momentum: 0.9
  weight_decay: 5e-4
  nesterov: true
lr_scheduler:
  class_path: torch.optim.lr_scheduler.MultiStepLR
  init_args:
    milestones:
      - 25
      - 50
    gamma: 0.1
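For orientation, the sketch below mirrors in plain Python what the `model` and `format_batch_fn` sections of this config assemble. It is a minimal sketch against torch_uncertainty as of this commit; the tuple-call convention for `RepeatTarget` and the output shape are assumptions, not documented guarantees.

import torch
from torch_uncertainty.models.lenet import batchensemble_lenet
from torch_uncertainty.transforms.batch import RepeatTarget

# As in the config: a BatchEnsemble LeNet with 5 estimators on 1-channel input.
model = batchensemble_lenet(in_channels=1, num_classes=10, num_estimators=5)

# RepeatTarget tiles the labels so the loss lines up with per-estimator
# outputs: with num_repeats=5, a batch of 8 labels becomes 40 labels.
fmt = RepeatTarget(num_repeats=5)
x = torch.randn(8, 1, 28, 28)           # dummy MNIST-sized batch
y = torch.randint(0, 10, (8,))
x, y = fmt((x, y))                      # x unchanged, y tiled to length 40

# Assumption: the BatchEnsemble wrapper repeats inputs across estimators
# internally, so the logits line up with the 40 repeated labels.
logits = model(x)                       # assumed shape: (5 * 8, 10)

Note that `num_repeats` must equal the model's `num_estimators` for this alignment to hold, which is why both are set to 5 in the config.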
78 changes: 78 additions & 0 deletions experiments/classification/mnist/configs/lenet_deep_ensemble.yaml
@@ -0,0 +1,78 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
  fast_dev_run: false
  accelerator: gpu
  devices: 1
  precision: 16-mixed
  max_epochs: 10
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: logs/lenet_trajectory
      name: deep_ensemble
      default_hp_metric: false
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val/cls/Acc
        mode: max
        save_last: true
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: val/cls/Acc
        patience: 1000
        check_finite: true
model:
  # ClassificationRoutine
  model:
    # DeepEnsemble
    class_path: torch_uncertainty.models.wrappers.deep_ensembles.deep_ensembles
    init_args:
      models:
        # LeNet
        class_path: torch_uncertainty.models.lenet._LeNet
        init_args:
          in_channels: 1
          num_classes: 10
          linear_layer: torch.nn.Linear
          conv2d_layer: torch.nn.Conv2d
          activation: torch.nn.ReLU
          norm: torch.nn.Identity
          groups: 1
          dropout_rate: 0
          # last_layer_dropout: false
          layer_args: {}
      num_estimators: 5
      task: classification
      probabilistic: false
      reset_model_parameters: true
  num_classes: 10
  loss: CrossEntropyLoss
  is_ensemble: true
  format_batch_fn:
    class_path: torch_uncertainty.transforms.batch.RepeatTarget
    init_args:
      num_repeats: 5
data:
  root: ./data
  batch_size: 128
  num_workers: 127
  eval_ood: true
  eval_shift: true
optimizer:
  lr: 0.05
  momentum: 0.9
  weight_decay: 5e-4
  nesterov: true
lr_scheduler:
  class_path: torch.optim.lr_scheduler.MultiStepLR
  init_args:
    milestones:
      - 25
      - 50
    gamma: 0.1
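The equivalent programmatic construction, sketched below, uses the public `lenet` builder in place of the private `_LeNet` class the YAML instantiates directly; it assumes `deep_ensembles` accepts a single template model plus `num_estimators`, which is how the config reads.

import torch
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.models.wrappers.deep_ensembles import deep_ensembles

# Five independent LeNets; reset_model_parameters=True re-initializes each
# copy so the members do not start from identical weights.
ensemble = deep_ensembles(
    lenet(in_channels=1, num_classes=10),
    num_estimators=5,
    task="classification",
    probabilistic=False,
    reset_model_parameters=True,
)

logits = ensemble(torch.randn(8, 1, 28, 28))  # assumed: one logit block per member, stacked on dim 0

Unlike BatchEnsemble, which shares the weight matrices and varies only cheap rank-1 factors per estimator, a deep ensemble trains `num_estimators` full copies of the network, so memory and compute scale linearly with ensemble size.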
28 changes: 28 additions & 0 deletions torch_uncertainty/models/lenet.py
@@ -5,10 +5,12 @@
import torch.nn.functional as F
from torch import nn

from torch_uncertainty.layers.batch_ensemble import BatchConv2d, BatchLinear
from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear
from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear
from torch_uncertainty.models import StochasticModel
from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble

__all__ = ["bayesian_lenet", "lenet", "packed_lenet"]

@@ -119,6 +121,32 @@ def lenet(
    )


def batchensemble_lenet(
    in_channels: int,
    num_classes: int,
    num_estimators: int = 4,
    activation: Callable = F.relu,
    norm: type[nn.Module] = nn.BatchNorm2d,
    groups: int = 1,
    dropout_rate: float = 0.0,
) -> BatchEnsemble:
    """Build a LeNet with BatchEnsemble layers, wrapped for ensemble use."""
    model = _lenet(
        stochastic=False,
        in_channels=in_channels,
        num_classes=num_classes,
        linear_layer=BatchLinear,
        conv2d_layer=BatchConv2d,
        layer_args={
            "num_estimators": num_estimators,
        },
        activation=activation,
        norm=norm,
        groups=groups,
        dropout_rate=dropout_rate,
    )
    return BatchEnsemble(model, num_estimators)


def packed_lenet(
    in_channels: int,
    num_classes: int,

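Finally, a quick smoke test of the new factory clarifies why it returns the `BatchEnsemble` wrapper rather than a bare `_LeNet`. The eval-time output layout in the comments is an assumption about how the wrapper repeats inputs across estimators and may differ between versions.

import torch
from torch_uncertainty.models.lenet import batchensemble_lenet

# The factory builds a _LeNet with BatchLinear/BatchConv2d layers, then
# hands it to the BatchEnsemble wrapper, which is what the caller receives.
net = batchensemble_lenet(in_channels=1, num_classes=10, num_estimators=4).eval()
print(type(net).__name__)                # BatchEnsemble, not _LeNet

out = net(torch.randn(2, 1, 28, 28))
print(out.shape)                         # assumed: torch.Size([8, 10]), 4 estimators x 2 samples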