Merge pull request #55 from FR-DC/FRML-119
FRML-119 Make Standard Scaler fit on segments only
Eve-ning authored Feb 21, 2024
2 parents 2873d2f + 9d95948 commit 6cdfb4f
Showing 7 changed files with 144 additions and 158 deletions.
20 changes: 15 additions & 5 deletions src/frdc/models/inceptionv3.py
@@ -25,11 +25,18 @@ def __init__(
x_scaler: StandardScaler,
y_encoder: OrdinalEncoder,
ema_lr: float = 0.001,
imagenet_scaling: bool = False,
):
"""Initialize the InceptionV3 model.
Args:
n_classes: The number of output classes
in_channels: The number of input channels.
n_classes: The number of classes.
lr: The learning rate.
x_scaler: The X input StandardScaler.
y_encoder: The Y input OrdinalEncoder.
ema_lr: The learning rate for the EMA model.
imagenet_scaling: Whether to use the adapted ImageNet scaling.
Notes:
- Min input size: 299 x 299.
@@ -46,6 +53,7 @@ def __init__(
sharpen_temp=0.5,
mix_beta_alpha=0.75,
)
self.imagenet_scaling = imagenet_scaling

self.inception = inception_v3(
weights=Inception_V3_Weights.IMAGENET1K_V1,
@@ -74,8 +82,8 @@ def __init__(
# The problem is that the deep copy runs even before the module is
# initialized, which means ema_model is empty.
ema_model = deepcopy(self)
for param in ema_model.parameters():
param.detach_()
# for param in ema_model.parameters():
# param.detach_()

self._ema_model = ema_model
self.ema_updater = EMA(model=self, ema_model=self.ema_model)
@@ -129,7 +137,7 @@ def adapt_inception_multi_channel(
return inception

@staticmethod
def transform_input(x: torch.Tensor) -> torch.Tensor:
def _imagenet_scaling(x: torch.Tensor) -> torch.Tensor:
"""Perform adapted ImageNet normalization on the input tensor.
See Also:
@@ -181,7 +189,9 @@ def forward(self, x: torch.Tensor):
f"Got: {x.shape[2]} x {x.shape[3]}."
)

x = self.transform_input(x)
if self.imagenet_scaling:
x = self._imagenet_scaling(x)

# During training, the auxiliary outputs are used for auxiliary loss,
# but during testing, only the main output is used.
if self.training:
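
Taken together, these hunks make the adapted ImageNet scaling opt-in: transform_input becomes the private _imagenet_scaling helper, and forward only applies it when the new imagenet_scaling flag is set (default False). A construction sketch follows; the class name and the arguments above the visible hunk (in_channels, n_classes, lr) are inferred from the docstring rather than taken from the diff.

# Hypothetical usage sketch: the class name and the first three arguments are
# assumptions based on the docstring; only x_scaler, y_encoder, ema_lr and
# imagenet_scaling are visible in this diff.
from sklearn.preprocessing import OrdinalEncoder, StandardScaler

from frdc.models.inceptionv3 import InceptionV3MixMatchModule  # name assumed

m = InceptionV3MixMatchModule(
    in_channels=8,
    n_classes=10,
    lr=1e-3,
    x_scaler=StandardScaler(),
    y_encoder=OrdinalEncoder(),
    ema_lr=0.001,
    imagenet_scaling=False,  # default: skip the adapted ImageNet scaling
)
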
2 changes: 0 additions & 2 deletions src/frdc/train/frdc_datamodule.py
@@ -53,7 +53,6 @@ class FRDCDataModule(LightningDataModule):
batch_size: The batch size to use for the dataloaders.
train_iters: The number of iterations to run for the labelled training
dataset.
val_iters: The number of iterations to run for the validation dataset.
"""

@@ -62,7 +61,6 @@ class FRDCDataModule(LightningDataModule):
train_unl_ds: FRDCDataset | FRDCUnlabelledDataset | None = None
batch_size: int = 4
train_iters: int = 100
val_iters: int = 100
sampling_strategy: Literal["stratified", "random"] = "stratified"

def __post_init__(self):
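
The datamodule loses its val_iters field, so only the labelled training iteration count stays configurable. A construction sketch under the dataclass fields visible here; train_lab_ds is an assumed field name, as it sits above the shown hunk.

# Sketch only: train_lab_ds is an assumed field name; the other fields are the
# dataclass fields visible in this diff. `ds` stands in for an FRDCDataset.
from frdc.train.frdc_datamodule import FRDCDataModule

dm = FRDCDataModule(
    train_lab_ds=ds,                  # labelled dataset (field name assumed)
    train_unl_ds=None,                # optional unlabelled dataset
    batch_size=4,
    train_iters=100,                  # val_iters no longer exists
    sampling_strategy="stratified",
)
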
126 changes: 45 additions & 81 deletions src/frdc/train/mixmatch_module.py
@@ -51,7 +51,6 @@ def __init__(
self.sharpen_temp = sharpen_temp
self.mix_beta_alpha = mix_beta_alpha
self.save_hyperparameters()
self.lbl_logger = WandBLabelLogger()

@property
@abstractmethod
@@ -151,20 +150,19 @@ def progress(self):

def training_step(self, batch, batch_idx):
(x_lbl, y_lbl), x_unls = batch
self.lbl_logger(
self.logger.experiment,
"Input Y Label",
y_lbl,
flush_every=10,
num_bins=self.n_classes,
)

self.log("train/x_lbl_mean", x_lbl.mean())
self.log("train/x_lbl_stdev", x_lbl.std())

wandb.log({"train/x_lbl": self.wandb_hist(y_lbl, self.n_classes)})
y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)

# If x_unls is Truthy, then we are using MixMatch.
# Otherwise, we are just using supervised learning.
if x_unls:
# This route implies that we are using SSL
self.log("train/x0_unl_mean", x_unls[0].mean())
self.log("train/x0_unl_stdev", x_unls[0].std())
with torch.no_grad():
y_unl = self.guess_labels(x_unls=x_unls)
y_unl = self.sharpen(y_unl, self.sharpen_temp)
@@ -183,42 +181,42 @@ def training_step(self, batch, batch_idx):
y_mix_unl = y_mix[batch_size:]

loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl)
self.lbl_logger(
self.logger.experiment,
"Labelled Y Pred",
torch.argmax(y_mix_lbl_pred, dim=1),
flush_every=10,
num_bins=self.n_classes,
)
loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
self.lbl_logger(
self.logger.experiment,
"Unlabelled Y Pred",
torch.argmax(y_mix_unl_pred, dim=1),
flush_every=10,
num_bins=self.n_classes,
wandb.log(
{
"train/y_lbl_pred": self.wandb_hist(
torch.argmax(y_mix_lbl_pred, dim=1), self.n_classes
)
}
)
wandb.log(
{
"train/y_unl_pred": self.wandb_hist(
torch.argmax(y_mix_unl_pred, dim=1), self.n_classes
)
}
)
loss_unl_scale = self.loss_unl_scaler(progress=self.progress)

loss = loss_lbl + loss_unl * loss_unl_scale

self.log("loss_unl_scale", loss_unl_scale, prog_bar=True)
self.log("train_loss_lbl", loss_lbl)
self.log("train_loss_unl", loss_unl)
self.log("train/loss_unl_scale", loss_unl_scale, prog_bar=True)
self.log("train/ce_loss_lbl", loss_lbl)
self.log("train/mse_loss_unl", loss_unl)
else:
# This route implies that we are just using supervised learning
y_pred = self(x_lbl)
loss = self.loss_lbl(y_pred, y_lbl_ohe.float())

self.log("train_loss", loss)
self.log("train/loss", loss)

# Evaluate train accuracy
with torch.no_grad():
y_pred = self.ema_model(x_lbl)
acc = accuracy(
y_pred, y_lbl, task="multiclass", num_classes=y_pred.shape[1]
)
self.log("train_acc", acc, prog_bar=True)
self.log("train/acc", acc, prog_bar=True)
return loss

# PyTorch Lightning doesn't automatically no_grads the EMA step.
@@ -227,30 +225,31 @@ def training_step(self, batch, batch_idx):
def on_after_backward(self) -> None:
self.update_ema()

@staticmethod
def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram:
return wandb.Histogram(
torch.flatten(x).detach().cpu().tolist(),
num_bins=num_bins,
)

def validation_step(self, batch, batch_idx):
x, y = batch
self.lbl_logger(
self.logger.experiment,
"Val Input Y Label",
y,
flush_every=1,
num_bins=self.n_classes,
)
wandb.log({"val/y_lbl": self.wandb_hist(y, self.n_classes)})
y_pred = self.ema_model(x)
self.lbl_logger(
self.logger.experiment,
"Val Pred Y Label",
torch.argmax(y_pred, dim=1),
flush_every=1,
num_bins=self.n_classes,
wandb.log(
{
"val/y_lbl_pred": self.wandb_hist(
torch.argmax(y_pred, dim=1), self.n_classes
)
}
)
loss = F.cross_entropy(y_pred, y.long())

acc = accuracy(
y_pred, y, task="multiclass", num_classes=y_pred.shape[1]
)
self.log("val_loss", loss)
self.log("val_acc", acc, prog_bar=True)
self.log("val/ce_loss", loss)
self.log("val/acc", acc, prog_bar=True)
return loss

def test_step(self, batch, batch_idx):
@@ -261,8 +260,8 @@ def test_step(self, batch, batch_idx):
acc = accuracy(
y_pred, y, task="multiclass", num_classes=y_pred.shape[1]
)
self.log("test_loss", loss)
self.log("test_acc", acc, prog_bar=True)
self.log("test/ce_loss", loss)
self.log("test/acc", acc, prog_bar=True)
return loss

def predict_step(self, batch, *args, **kwargs) -> Any:
@@ -305,7 +304,7 @@ def x_trans_fn(x):

# Move Channel back to the second dimension
# B x H x W x C -> B x C x H x W
return (
return torch.nan_to_num(
torch.from_numpy(x_ss.reshape(b, h, w, c))
.permute(0, 3, 1, 2)
.float()
@@ -335,46 +334,11 @@ def y_trans_fn(y):
nan = ~torch.isnan(y_trans)
x_lab_trans = x_lab_trans[nan]
x_unl_trans = [x[nan] for x in x_unl_trans]
x_lab_trans = torch.nan_to_num(x_lab_trans)
x_unl_trans = [torch.nan_to_num(x) for x in x_unl_trans]
y_trans = y_trans[nan]

if self.training:
return (x_lab_trans, y_trans.long()), x_unl_trans
else:
return x_lab_trans, y_trans.long()


class WandBLabelLogger(dict):
"""Logger to log y labels to WandB"""

def __call__(
self,
logger: wandb.sdk.wandb_run.Run,
key: str,
value: torch.Tensor,
num_bins: int,
flush_every: int = 10,
):
"""Log the labels to WandB
Args:
logger: The W&B logger. Accessible through `self.logger.experiment`
key: The key to log the labels under.
value: The labels to log.
flush_every: How often to flush the labels to WandB.
"""
if key not in self.keys():
self[key] = [value]
else:
self[key].append(value)

if len(self[key]) % flush_every == 0:
logger.log(
{
key: wandb.Histogram(
torch.flatten(value).detach().cpu().tolist(),
num_bins=num_bins,
)
}
)
self[key] = []
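
The headline change of FRML-119 surfaces in x_trans_fn: the StandardScaler is now expected to be fitted on segment pixels only, and the transformed batch goes through torch.nan_to_num so the NaN padding around each segment is zero-filled after scaling. A minimal sketch of that behaviour, with the fitting step written out as an assumption since it is not part of this diff:

# Minimal sketch of the intended behaviour, not the repository's exact code:
# fit the StandardScaler on segment (non-NaN) pixels only, then zero-fill the
# NaN padding with torch.nan_to_num after the transform, mirroring the diff.
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

x = torch.rand(4, 3, 8, 8)                   # B x C x H x W segment crops
x[:, :, :2, :] = float("nan")                # NaN padding outside the segments
b, c, h, w = x.shape

px = x.permute(0, 2, 3, 1).reshape(-1, c).numpy()             # (B*H*W) x C
scaler = StandardScaler().fit(px[~np.isnan(px).any(axis=1)])  # segments only

x_ss = scaler.transform(px)                  # NaNs pass through the scaler
x_scaled = torch.nan_to_num(                 # zero-fill padding, as in the diff
    torch.from_numpy(x_ss.reshape(b, h, w, c)).permute(0, 3, 1, 2).float()
)
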
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -1,9 +1,12 @@
import numpy as np
import pytest
import wandb

from frdc.load.dataset import FRDCDataset
from frdc.load.preset import FRDCDatasetPreset

wandb.init(mode="disabled")


@pytest.fixture(scope="session")
def ds() -> FRDCDataset:
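
Initialising W&B in disabled mode at conftest import time keeps the new module-level wandb.log calls from needing a live run or credentials during tests. As a hedged alternative, the same guard could live in an autouse session fixture; this is an assumption, not what the repository does:

# Alternative sketch (assumption): an autouse session fixture that disables
# wandb for the whole test session instead of a module-level init.
import pytest
import wandb


@pytest.fixture(scope="session", autouse=True)
def disabled_wandb():
    run = wandb.init(mode="disabled")
    yield run
    run.finish()
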
2 changes: 1 addition & 1 deletion tests/integration_tests/test_pipeline.py
@@ -42,5 +42,5 @@ def test_manual_segmentation_pipeline(ds):
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(m, datamodule=dm)

val_loss = trainer.validate(m, datamodule=dm)[0]["val_loss"]
val_loss = trainer.validate(m, datamodule=dm)[0]["val/ce_loss"]
logging.debug(f"Validation score: {val_loss:.2%}")
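
Because trainer.validate returns a list with one metrics dict per validation dataloader, keyed by the names passed to self.log, the test has to follow the key rename. A short sketch continuing from the test above, assuming a single validation dataloader:

# Sketch: the keys below follow the renamed self.log calls in mixmatch_module.
results = trainer.validate(m, datamodule=dm)   # one dict per val dataloader
val_loss = results[0]["val/ce_loss"]
val_acc = results[0]["val/acc"]
logging.debug(f"val/ce_loss={val_loss:.3f}, val/acc={val_acc:.2%}")
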