FRML-119 Make Standard Scaler fit on segments only #55

Merged · 11 commits · Feb 21, 2024
20 changes: 15 additions & 5 deletions src/frdc/models/inceptionv3.py
@@ -25,11 +25,18 @@ def __init__(
x_scaler: StandardScaler,
y_encoder: OrdinalEncoder,
ema_lr: float = 0.001,
imagenet_scaling: bool = False,
):
"""Initialize the InceptionV3 model.

Args:
n_classes: The number of output classes
in_channels: The number of input channels.
n_classes: The number of classes.
lr: The learning rate.
x_scaler: The X input StandardScaler.
y_encoder: The Y input OrdinalEncoder.
ema_lr: The learning rate for the EMA model.
imagenet_scaling: Whether to use the adapted ImageNet scaling.

Notes:
- Min input size: 299 x 299.
@@ -46,6 +53,7 @@ def __init__(
sharpen_temp=0.5,
mix_beta_alpha=0.75,
)
self.imagenet_scaling = imagenet_scaling

self.inception = inception_v3(
weights=Inception_V3_Weights.IMAGENET1K_V1,
@@ -74,8 +82,8 @@ def __init__(
# The problem is that the deep copy runs even before the module is
# initialized, which means ema_model is empty.
ema_model = deepcopy(self)
for param in ema_model.parameters():
param.detach_()
# for param in ema_model.parameters():
# param.detach_()

self._ema_model = ema_model
self.ema_updater = EMA(model=self, ema_model=self.ema_model)
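The commented-out detach loop above, together with the `on_after_backward` hook in `mixmatch_module.py` below, hints at how the EMA copy is kept in sync: the update itself must run without gradient tracking. The repository's `EMA` class is not shown in this diff, so the following is only a minimal sketch of that idea, with the blend factor named `lr` to mirror the `ema_lr` argument above.

import torch
import torch.nn as nn


class EMASketch:
    """Illustrative stand-in for the `EMA` updater used above (not the real class)."""

    def __init__(self, model: nn.Module, ema_model: nn.Module, lr: float = 0.001):
        self.model = model
        self.ema_model = ema_model
        self.lr = lr

    @torch.no_grad()  # the EMA step must stay out of the autograd graph
    def update(self):
        for p, ema_p in zip(self.model.parameters(), self.ema_model.parameters()):
            # Nudge each EMA weight a small step towards the live weight.
            ema_p.add_(p - ema_p, alpha=self.lr)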
@@ -129,7 +137,7 @@ def adapt_inception_multi_channel(
return inception

@staticmethod
def transform_input(x: torch.Tensor) -> torch.Tensor:
def _imagenet_scaling(x: torch.Tensor) -> torch.Tensor:
"""Perform adapted ImageNet normalization on the input tensor.

See Also:
@@ -181,7 +189,9 @@ def forward(self, x: torch.Tensor):
f"Got: {x.shape[2]} x {x.shape[3]}."
)

x = self.transform_input(x)
if self.imagenet_scaling:
x = self._imagenet_scaling(x)

# During training, the auxiliary outputs are used for auxiliary loss,
# but during testing, only the main output is used.
if self.training:
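The new `imagenet_scaling` flag turns the previously unconditional `transform_input` call into an opt-in step guarded in `forward`. The body of `_imagenet_scaling` is not visible in this hunk; as a rough sketch of the idea, torchvision's Inception `transform_input` remaps inputs normalized with mean 0.5 / std 0.5 onto the per-channel ImageNet statistics, and one way to adapt it to an arbitrary channel count is to reuse the averaged ImageNet mean/std for every channel. The constants below are that assumption, not necessarily what the repository uses.

import torch


def imagenet_scaling_sketch(x: torch.Tensor) -> torch.Tensor:
    # Assumes x is normalized to mean 0.5 / std 0.5 per channel.
    # 0.449 / 0.226 are the ImageNet RGB mean/std averaged over channels,
    # applied uniformly because the input may have more than 3 bands.
    mean, std = 0.449, 0.226
    return x * (std / 0.5) + (mean - 0.5) / 0.5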
2 changes: 0 additions & 2 deletions src/frdc/train/frdc_datamodule.py
@@ -53,7 +53,6 @@ class FRDCDataModule(LightningDataModule):
batch_size: The batch size to use for the dataloaders.
train_iters: The number of iterations to run for the labelled training
dataset.
val_iters: The number of iterations to run for the validation dataset.

"""

@@ -62,7 +61,6 @@ class FRDCDataModule(LightningDataModule):
train_unl_ds: FRDCDataset | FRDCUnlabelledDataset | None = None
batch_size: int = 4
train_iters: int = 100
val_iters: int = 100
sampling_strategy: Literal["stratified", "random"] = "stratified"

def __post_init__(self):
126 changes: 45 additions & 81 deletions src/frdc/train/mixmatch_module.py
@@ -51,7 +51,6 @@ def __init__(
self.sharpen_temp = sharpen_temp
self.mix_beta_alpha = mix_beta_alpha
self.save_hyperparameters()
self.lbl_logger = WandBLabelLogger()

@property
@abstractmethod
@@ -151,20 +150,19 @@ def progress(self):

def training_step(self, batch, batch_idx):
(x_lbl, y_lbl), x_unls = batch
self.lbl_logger(
self.logger.experiment,
"Input Y Label",
y_lbl,
flush_every=10,
num_bins=self.n_classes,
)

self.log("train/x_lbl_mean", x_lbl.mean())
self.log("train/x_lbl_stdev", x_lbl.std())

wandb.log({"train/x_lbl": self.wandb_hist(y_lbl, self.n_classes)})
y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)

# If x_unls is Truthy, then we are using MixMatch.
# Otherwise, we are just using supervised learning.
if x_unls:
# This route implies that we are using SSL
self.log("train/x0_unl_mean", x_unls[0].mean())
self.log("train/x0_unl_stdev", x_unls[0].std())
with torch.no_grad():
y_unl = self.guess_labels(x_unls=x_unls)
y_unl = self.sharpen(y_unl, self.sharpen_temp)
@@ -183,42 +181,42 @@ def training_step(self, batch, batch_idx):
y_mix_unl = y_mix[batch_size:]

loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl)
self.lbl_logger(
self.logger.experiment,
"Labelled Y Pred",
torch.argmax(y_mix_lbl_pred, dim=1),
flush_every=10,
num_bins=self.n_classes,
)
loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
self.lbl_logger(
self.logger.experiment,
"Unlabelled Y Pred",
torch.argmax(y_mix_unl_pred, dim=1),
flush_every=10,
num_bins=self.n_classes,
wandb.log(
{
"train/y_lbl_pred": self.wandb_hist(
torch.argmax(y_mix_lbl_pred, dim=1), self.n_classes
)
}
)
wandb.log(
{
"train/y_unl_pred": self.wandb_hist(
torch.argmax(y_mix_unl_pred, dim=1), self.n_classes
)
}
)
loss_unl_scale = self.loss_unl_scaler(progress=self.progress)

loss = loss_lbl + loss_unl * loss_unl_scale

self.log("loss_unl_scale", loss_unl_scale, prog_bar=True)
self.log("train_loss_lbl", loss_lbl)
self.log("train_loss_unl", loss_unl)
self.log("train/loss_unl_scale", loss_unl_scale, prog_bar=True)
self.log("train/ce_loss_lbl", loss_lbl)
self.log("train/mse_loss_unl", loss_unl)
else:
# This route implies that we are just using supervised learning
y_pred = self(x_lbl)
loss = self.loss_lbl(y_pred, y_lbl_ohe.float())

self.log("train_loss", loss)
self.log("train/loss", loss)

# Evaluate train accuracy
with torch.no_grad():
y_pred = self.ema_model(x_lbl)
acc = accuracy(
y_pred, y_lbl, task="multiclass", num_classes=y_pred.shape[1]
)
self.log("train_acc", acc, prog_bar=True)
self.log("train/acc", acc, prog_bar=True)
return loss

# PyTorch Lightning doesn't automatically no_grads the EMA step.
@@ -227,30 +225,31 @@ def training_step(self, batch, batch_idx):
def on_after_backward(self) -> None:
self.update_ema()

@staticmethod
def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram:
return wandb.Histogram(
torch.flatten(x).detach().cpu().tolist(),
num_bins=num_bins,
)

def validation_step(self, batch, batch_idx):
x, y = batch
self.lbl_logger(
self.logger.experiment,
"Val Input Y Label",
y,
flush_every=1,
num_bins=self.n_classes,
)
wandb.log({"val/y_lbl": self.wandb_hist(y, self.n_classes)})
y_pred = self.ema_model(x)
self.lbl_logger(
self.logger.experiment,
"Val Pred Y Label",
torch.argmax(y_pred, dim=1),
flush_every=1,
num_bins=self.n_classes,
wandb.log(
{
"val/y_lbl_pred": self.wandb_hist(
torch.argmax(y_pred, dim=1), self.n_classes
)
}
)
loss = F.cross_entropy(y_pred, y.long())

acc = accuracy(
y_pred, y, task="multiclass", num_classes=y_pred.shape[1]
)
self.log("val_loss", loss)
self.log("val_acc", acc, prog_bar=True)
self.log("val/ce_loss", loss)
self.log("val/acc", acc, prog_bar=True)
return loss

def test_step(self, batch, batch_idx):
@@ -261,8 +260,8 @@ def test_step(self, batch, batch_idx):
acc = accuracy(
y_pred, y, task="multiclass", num_classes=y_pred.shape[1]
)
self.log("test_loss", loss)
self.log("test_acc", acc, prog_bar=True)
self.log("test/ce_loss", loss)
self.log("test/acc", acc, prog_bar=True)
return loss

def predict_step(self, batch, *args, **kwargs) -> Any:
@@ -305,7 +304,7 @@ def x_trans_fn(x):

# Move Channel back to the second dimension
# B x H x W x C -> B x C x H x W
return (
return torch.nan_to_num(
torch.from_numpy(x_ss.reshape(b, h, w, c))
.permute(0, 3, 1, 2)
.float()
@@ -335,46 +334,11 @@ def y_trans_fn(y):
nan = ~torch.isnan(y_trans)
x_lab_trans = x_lab_trans[nan]
x_unl_trans = [x[nan] for x in x_unl_trans]
x_lab_trans = torch.nan_to_num(x_lab_trans)
x_unl_trans = [torch.nan_to_num(x) for x in x_unl_trans]
y_trans = y_trans[nan]

if self.training:
return (x_lab_trans, y_trans.long()), x_unl_trans
else:
return x_lab_trans, y_trans.long()


class WandBLabelLogger(dict):
"""Logger to log y labels to WandB"""

def __call__(
self,
logger: wandb.sdk.wandb_run.Run,
key: str,
value: torch.Tensor,
num_bins: int,
flush_every: int = 10,
):
"""Log the labels to WandB

Args:
logger: The W&B logger. Accessible through `self.logger.experiment`
key: The key to log the labels under.
value: The labels to log.
flush_every: How often to flush the labels to WandB.

"""
if key not in self.keys():
self[key] = [value]
else:
self[key].append(value)

if len(self[key]) % flush_every == 0:
logger.log(
{
key: wandb.Histogram(
torch.flatten(value).detach().cpu().tolist(),
num_bins=num_bins,
)
}
)
self[key] = []
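The reshaping in `x_trans_fn` above is the heart of the PR title: `StandardScaler` expects a 2-D `(n_samples, n_features)` array, so each batch is flattened from B x C x H x W to (B * H * W) x C before `transform`, and the NaN padding outside the segments is zeroed afterwards with `torch.nan_to_num`. The corresponding fit is not shown in this diff; a minimal sketch, assuming segment pixels are exactly the non-NaN entries and `x_scaler` is a plain scikit-learn `StandardScaler`, could look like this.

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler


def fit_scaler_on_segments(x: torch.Tensor) -> StandardScaler:
    # x: B x C x H x W on CPU, NaN outside the segmented tree crowns (assumption).
    b, c, h, w = x.shape
    # B x C x H x W -> (B * H * W) x C, the layout StandardScaler expects.
    x_flat = x.permute(0, 2, 3, 1).reshape(-1, c).numpy()
    # Fit only on rows where every channel is a real segment pixel.
    seg_rows = ~np.isnan(x_flat).any(axis=1)
    return StandardScaler().fit(x_flat[seg_rows])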
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -1,9 +1,12 @@
import numpy as np
import pytest
import wandb

from frdc.load.dataset import FRDCDataset
from frdc.load.preset import FRDCDatasetPreset

wandb.init(mode="disabled")


@pytest.fixture(scope="session")
def ds() -> FRDCDataset:
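Disabling W&B at import time is what lets the unconditional `wandb.log` calls added in `mixmatch_module.py` run inside the test suite without a network connection or API key: in disabled mode every logging call is a no-op. A small sketch of the behaviour the tests rely on:

import wandb

wandb.init(mode="disabled")    # no run is created, nothing is uploaded
wandb.log({"train/acc": 0.5})  # silently ignored
wandb.log({"val/y_lbl": wandb.Histogram([0, 1, 1, 2], num_bins=3)})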
2 changes: 1 addition & 1 deletion tests/integration_tests/test_pipeline.py
@@ -42,5 +42,5 @@ def test_manual_segmentation_pipeline(ds):
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(m, datamodule=dm)

val_loss = trainer.validate(m, datamodule=dm)[0]["val_loss"]
val_loss = trainer.validate(m, datamodule=dm)[0]["val/ce_loss"]
logging.debug(f"Validation score: {val_loss:.2%}")