Skip to content

Commit

Permalink
Merge pull request #57 from FR-DC/FRML-120
Browse files Browse the repository at this point in the history
FRML-120 Implement EfficientNetB1
  • Loading branch information
Eve-ning authored Feb 21, 2024
2 parents 6cdfb4f + b577933 commit f5724d4
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 45 deletions.
136 changes: 136 additions & 0 deletions src/frdc/models/efficientnetb1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from copy import deepcopy

import torch
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from torch import nn
from torchvision.models import (
EfficientNet,
efficientnet_b1,
EfficientNet_B1_Weights,
)

from frdc.train.mixmatch_module import MixMatchModule
from frdc.utils.ema import EMA


class EfficientNetB1MixMatchModule(MixMatchModule):
    """MixMatch module built on an ImageNet-pretrained EfficientNet-B1.

    The EfficientNet classifier is replaced with an identity, its first
    convolution is swapped for one accepting ``in_channels`` inputs, and a
    linear + softmax head maps the features to ``n_classes`` probabilities.
    """

    # Minimum supported input size (height and width), in pixels.
    MIN_SIZE = 320
    # Width of the feature vector fed into the classification head.
    EFF_OUT_DIMS = 1280

    def __init__(
        self,
        *,
        in_channels: int,
        n_classes: int,
        lr: float,
        x_scaler: StandardScaler,
        y_encoder: OrdinalEncoder,
        ema_lr: float = 0.001,
        weight_decay: float = 1e-5,
        frozen: bool = True,
    ):
        """Initialize the EfficientNet model.

        Args:
            in_channels: The number of input channels.
            n_classes: The number of classes.
            lr: The learning rate.
            x_scaler: The X input StandardScaler.
            y_encoder: The Y input OrdinalEncoder.
            ema_lr: The learning rate for the EMA model.
            weight_decay: The weight decay.
            frozen: Whether to freeze the base model.

        Notes:
            - Min input size: 320 x 320
        """
        self.lr = lr
        self.weight_decay = weight_decay

        super().__init__(
            n_classes=n_classes,
            x_scaler=x_scaler,
            y_encoder=y_encoder,
            sharpen_temp=0.5,
            mix_beta_alpha=0.75,
        )

        self.eff = efficientnet_b1(
            weights=EfficientNet_B1_Weights.IMAGENET1K_V2
        )

        # Remove the final layer
        self.eff.classifier = nn.Identity()

        if frozen:
            # Freezing happens before the first conv is replaced below, so
            # the replacement conv is created afterwards and stays trainable.
            for param in self.eff.parameters():
                param.requires_grad = False

        # Adapt the first layer to accept the number of channels
        self.eff = self.adapt_efficient_multi_channel(self.eff, in_channels)

        self.fc = nn.Sequential(
            nn.Linear(self.EFF_OUT_DIMS, n_classes),
            nn.Softmax(dim=1),
        )

        # The EMA copy is taken here, once every layer has been attached.
        # Deep-copying before the module is fully initialized would leave
        # ema_model empty.
        ema_model = deepcopy(self)
        for param in ema_model.parameters():
            param.detach_()

        self._ema_model = ema_model
        self.ema_updater = EMA(model=self, ema_model=self.ema_model)
        self.ema_lr = ema_lr

    @staticmethod
    def adapt_efficient_multi_channel(
        eff: EfficientNet,
        in_channels: int,
    ) -> EfficientNet:
        """Adapt the EfficientNet model to accept a different number of
        input channels.

        The first convolution is replaced by a new one with ``in_channels``
        inputs. Pretrained filters are copied for the channels both layers
        share; every additional channel is initialized from the filter of
        pretrained channel index 1 (the second pretrained channel).

        Notes:
            This operation is in-place, however will still return the model

        Args:
            eff: The EfficientNet model
            in_channels: The number of input channels

        Returns:
            The adapted EfficientNet model.
        """
        old_conv = eff.features[0][0]
        has_bias = old_conv.bias is not None
        new_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            # nn.Conv2d expects a bool here, not the bias tensor itself.
            bias=has_bias,
        )

        # Number of leading channels whose pretrained filters can be reused
        # as-is; fewer than the pretrained count when in_channels is small.
        n_shared = min(in_channels, old_conv.in_channels)
        n_extra = in_channels - n_shared

        with torch.no_grad():
            new_conv.weight[:, :n_shared] = old_conv.weight[:, :n_shared]
            if n_extra > 0:
                # Fill every extra channel from pretrained channel 1; the
                # repeat count follows in_channels instead of a fixed 5.
                new_conv.weight[:, n_shared:] = old_conv.weight[
                    :, 1:2
                ].repeat(1, n_extra, 1, 1)
            if has_bias:
                new_conv.bias.copy_(old_conv.bias)

        eff.features[0][0] = new_conv

        return eff

    @property
    def ema_model(self):
        """The detached deep copy of this module created in ``__init__``."""
        return self._ema_model

    def update_ema(self):
        """Run one EMA update step using ``ema_lr``."""
        self.ema_updater.update(self.ema_lr)

    def forward(self, x: torch.Tensor):
        """Forward pass: EfficientNet features, then the softmax head."""
        return self.fc(self.eff(x))

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's parameters."""
        return torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
6 changes: 2 additions & 4 deletions src/frdc/utils/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,11 @@ def plot_confusion_matrix(

def predict(
ds: FRDCDataset,
model_cls: type[pl.LightningModule],
ckpt_pth: Path | str | None = None,
model: pl.LightningModule,
) -> tuple[np.ndarray, np.ndarray]:
m = model_cls.load_from_checkpoint(ckpt_pth)
# Make predictions
trainer = pl.Trainer(logger=False)
pred = trainer.predict(m, dataloaders=DataLoader(ds, batch_size=32))
pred = trainer.predict(model, dataloaders=DataLoader(ds, batch_size=32))

y_preds = []
y_trues = []
Expand Down
22 changes: 12 additions & 10 deletions tests/model_tests/chestnut_dec_may/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sklearn.preprocessing import StandardScaler, OrdinalEncoder

from frdc.load.preset import FRDCDatasetPreset as ds
from frdc.models.inceptionv3 import InceptionV3MixMatchModule
from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule
from frdc.train.frdc_datamodule import FRDCDataModule
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
Expand Down Expand Up @@ -61,11 +61,14 @@ def main(
wandb_project="frdc",
):
# Prepare the dataset
train_lab_ds = ds.chestnut_20201218(transform=train_preprocess_augment)
im_size = 299
train_lab_ds = ds.chestnut_20201218(
transform=train_preprocess_augment(im_size)
)
train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=train_unl_preprocess(2)
transform=train_unl_preprocess(im_size, 2)
)
val_ds = ds.chestnut_20210510_43m(transform=val_preprocess)
val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size))

# Prepare the datamodule and trainer
dm = FRDCDataModule(
Expand Down Expand Up @@ -103,13 +106,13 @@ def main(
oe = get_y_encoder(train_lab_ds.targets)
ss = get_x_scaler(train_lab_ds.ar_segments)

m = InceptionV3MixMatchModule(
m = EfficientNetB1MixMatchModule(
in_channels=train_lab_ds.ar.shape[-1],
n_classes=len(oe.categories_[0]),
lr=lr,
x_scaler=ss,
y_encoder=oe,
imagenet_scaling=True,
frozen=True,
)

trainer.fit(m, datamodule=dm)
Expand All @@ -125,10 +128,9 @@ def main(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
transform=val_preprocess,
transform=val_preprocess(im_size),
),
model_cls=InceptionV3MixMatchModule,
ckpt_pth=Path(ckpt.best_model_path),
model=m,
)
fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0])
acc = np.sum(y_true == y_pred) / len(y_true)
Expand All @@ -152,6 +154,6 @@ def main(
epochs=EPOCHS,
train_iters=TRAIN_ITERS,
lr=LR,
wandb_name="Try with Inception Unfrozen & Random Erasing",
wandb_name="EfficientNet 299x299",
wandb_project="frdc-dev",
)
45 changes: 14 additions & 31 deletions tests/model_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
RandomRotation,
RandomApply,
Resize,
RandomErasing,
)
from torchvision.transforms.v2 import RandomHorizontalFlip

from frdc.load.dataset import FRDCDataset
from frdc.models.inceptionv3 import InceptionV3MixMatchModule

THIS_DIR = Path(__file__).parent

Expand Down Expand Up @@ -49,49 +47,34 @@ def __getitem__(self, idx):
return RandomHorizontalFlip(p=1)(RandomVerticalFlip(p=1)(x)), y


def val_preprocess(x):
return Compose(
def val_preprocess(size: int):
return lambda x: Compose(
[
ToImage(),
ToDtype(torch.float32, scale=True),
Resize(
InceptionV3MixMatchModule.MIN_SIZE,
antialias=True,
),
CenterCrop(
InceptionV3MixMatchModule.MIN_SIZE,
),
Resize(size, antialias=True),
CenterCrop(size),
]
)(x)


def train_preprocess_augment(x):
return Compose(
def train_preprocess_augment(size: int):
return lambda x: Compose(
[
ToImage(),
ToDtype(torch.float32, scale=True),
Resize(
InceptionV3MixMatchModule.MIN_SIZE,
antialias=True,
),
RandomCrop(
InceptionV3MixMatchModule.MIN_SIZE,
pad_if_needed=False,
),
Resize(size, antialias=True),
RandomCrop(size, pad_if_needed=False),
RandomHorizontalFlip(),
RandomVerticalFlip(),
RandomApply([RandomRotation((90, 90))], p=0.5),
]
)(x)


def train_unl_preprocess(n_aug: int = 2):
def f(x):
# This simulates the n_aug of MixMatch
return (
[train_preprocess_augment(x) for _ in range(n_aug)]
if n_aug > 0
else None
)

return f
def train_unl_preprocess(size: int, n_aug: int = 2):
    """Build the transform for unlabelled training samples.

    This simulates the ``n_aug`` augmentations of MixMatch: the returned
    callable applies the training augmentation ``n_aug`` times to the same
    input and returns the results as a list.

    Args:
        size: Image size forwarded to ``train_preprocess_augment``.
        n_aug: Number of augmented copies to produce per sample.

    Returns:
        A callable taking one sample. It returns a list of ``n_aug``
        augmented samples, or ``None`` when ``n_aug`` is not positive.
    """
    if n_aug <= 0:
        # Nothing to augment; keep returning None for every sample.
        return lambda x: None

    # Build the augmentation once instead of once per copy per sample.
    augment = train_preprocess_augment(size)
    return lambda x: [augment(x) for _ in range(n_aug)]

0 comments on commit f5724d4

Please sign in to comment.