Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FRML-102 Implement W&B Label Dist Monitoring #42

Merged
merged 3 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions src/frdc/train/mixmatch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.nn.functional as F
import torch.nn.parallel
import torch.nn.parallel
import wandb
from lightning import LightningModule
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from torch.nn.functional import one_hot
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
self.sharpen_temp = sharpen_temp
self.mix_beta_alpha = mix_beta_alpha
self.save_hyperparameters()
self.lbl_logger = WandBLabelLogger()

@property
@abstractmethod
Expand Down Expand Up @@ -150,10 +152,12 @@ def progress(self):
) / self.trainer.max_epochs

def training_step(self, batch, batch_idx):
# Progress is a linear ramp from 0 to 1 over the course of training.
(x_lbl, y_lbl), x_unls = batch
self.lbl_logger(
self.logger.experiment, "Input Y Label", y_lbl, flush_every=10
)

y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes)
y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)

# If x_unls is Truthy, then we are using MixMatch.
# Otherwise, we are just using supervised learning.
Expand All @@ -164,7 +168,7 @@ def training_step(self, batch, batch_idx):
y_unl = self.sharpen(y_unl, self.sharpen_temp)

x = torch.cat([x_lbl, *x_unls], dim=0)
y = torch.cat([y_lbl, *(y_unl,) * len(x_unls)], dim=0)
y = torch.cat([y_lbl_ohe, *(y_unl,) * len(x_unls)], dim=0)
x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha)

# This had interleaving, but it was removed as it's not
Expand All @@ -177,7 +181,19 @@ def training_step(self, batch, batch_idx):
y_mix_unl = y_mix[batch_size:]

loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl)
self.lbl_logger(
self.logger.experiment,
"Labelled Y Pred",
torch.argmax(y_mix_lbl_pred, dim=1),
flush_every=10,
)
loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
self.lbl_logger(
self.logger.experiment,
"Unlabelled Y Pred",
torch.argmax(y_mix_unl_pred, dim=1),
flush_every=10,
)
loss_unl_scale = self.loss_unl_scaler(progress=self.progress)

loss = loss_lbl + loss_unl * loss_unl_scale
Expand All @@ -188,7 +204,7 @@ def training_step(self, batch, batch_idx):
else:
# This route implies that we are just using supervised learning
y_pred = self(x_lbl)
loss = self.loss_lbl(y_pred, y_lbl.float())
loss = self.loss_lbl(y_pred, y_lbl_ohe.float())

self.log("train_loss", loss)
return loss
Expand All @@ -201,7 +217,16 @@ def on_after_backward(self) -> None:

def validation_step(self, batch, batch_idx):
x, y = batch
self.lbl_logger(
self.logger.experiment, "Val Input Y Label", y, flush_every=1
)
y_pred = self.ema_model(x)
self.lbl_logger(
self.logger.experiment,
"Val Pred Y Label",
torch.argmax(y_pred, dim=1),
flush_every=1,
)
loss = F.cross_entropy(y_pred, y.long())

acc = accuracy(
Expand Down Expand Up @@ -299,3 +324,38 @@ def y_trans_fn(y):
return (x_lab_trans, y_trans.long()), x_unl_trans
else:
return x_lab_trans, y_trans.long()


class WandBLabelLogger(dict):
"""Logger to log y labels to WandB"""

def __call__(
self,
logger: wandb.sdk.wandb_run.Run,
key: str,
value: torch.Tensor,
flush_every: int = 10,
):
"""Log the labels to WandB

Args:
logger: The W&B logger. Accessible through `self.logger.experiment`
key: The key to log the labels under.
value: The labels to log.
flush_every: How often to flush the labels to WandB.

"""
if key not in self.keys():
self[key] = [value]
else:
self[key].append(value)

if len(self[key]) % flush_every == 0:
logger.log(
{
key: wandb.Histogram(
torch.flatten(value).detach().cpu().tolist()
)
}
)
self[key] = []
9 changes: 5 additions & 4 deletions tests/model_tests/chestnut_dec_may/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ def main(
):
run = wandb.init()
logger = WandbLogger(name="chestnut_dec_may", project="frdc")

# Prepare the dataset
train_lab_ds = ds.chestnut_20201218(transform=train_preprocess)

train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=train_unl_preprocess(2)
)

val_ds = ds.chestnut_20210510_43m(transform=preprocess)

oe = OrdinalEncoder(
Expand All @@ -65,12 +64,12 @@ def main(
# Prepare the datamodule and trainer
dm = FRDCDataModule(
train_lab_ds=train_lab_ds,
# Pass in None to use the default supervised DM
train_unl_ds=train_unl_ds,
train_unl_ds=train_unl_ds, # None to use supervised DM
val_ds=val_ds,
batch_size=batch_size,
train_iters=train_iters,
val_iters=val_iters,
sampling_strategy="stratified",
)

trainer = pl.Trainer(
Expand All @@ -90,12 +89,14 @@ def main(
],
logger=logger,
)

m = InceptionV3MixMatchModule(
n_classes=n_classes,
lr=lr,
x_scaler=ss,
y_encoder=oe,
)
logger.watch(m)

trainer.fit(m, datamodule=dm)

Expand Down