FRRD-80 Add Visualization Notebook for FRRD-77 #67

Merged · 5 commits · Jun 4, 2024
2 changes: 0 additions & 2 deletions notebooks/README.md → notebooks/frrd-60/README.md
@@ -1,5 +1,3 @@
# Jupyter Notebooks

## FRRD-60: Consistency as a feature discriminator for Novelty Detection

Hypothesis: Consistency can be used as a measure to discriminate between normal
1 change: 1 addition & 0 deletions notebooks/frrd-80/.gitignore
@@ -0,0 +1 @@
!*.ckpt
5 changes: 5 additions & 0 deletions notebooks/frrd-80/README.md
@@ -0,0 +1,5 @@
# Jupyter Notebooks

## FRRD-80: Visualization of FixMatch

This notebook visualizes the decision boundaries learned by the FixMatch model.
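
Since `nb.ipynb` is too large to render in this diff, the following is a minimal sketch of the kind of decision-boundary plot the README above describes, using the newly added `plotly` dependency. The toy classifier, 2-D feature space, and grid ranges are illustrative assumptions, not code taken from the notebook.

```python
# Illustrative only: a toy 2-D decision-boundary contour plot with plotly.
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn

# Hypothetical stand-in for a trained classifier over 2-D features.
model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3))
model.eval()

# Evaluate the model on a dense 2-D grid.
xs = np.linspace(-3, 3, 200)
ys = np.linspace(-3, 3, 200)
xx, yy = np.meshgrid(xs, ys)
grid = torch.tensor(
    np.stack([xx.ravel(), yy.ravel()], axis=1), dtype=torch.float32
)

with torch.no_grad():
    pred = model(grid).argmax(dim=1).reshape(xx.shape).numpy()

# Each predicted class forms a coloured region; the decision boundaries
# are where the predicted class changes.
go.Figure(go.Contour(x=xs, y=ys, z=pred, showscale=False)).show()
```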
Binary file added notebooks/frrd-80/epoch=0-step=150.ckpt
Binary file added notebooks/frrd-80/epoch=1-step=300.ckpt
Binary file added notebooks/frrd-80/epoch=2-step=450.ckpt
Binary file added notebooks/frrd-80/epoch=3-step=600.ckpt
Binary file added notebooks/frrd-80/epoch=4-step=750.ckpt
Binary file added notebooks/frrd-80/epoch=5-step=900.ckpt
54,255 changes: 54,255 additions & 0 deletions notebooks/frrd-80/nb.ipynb

Large diffs are not rendered by default.

32 changes: 31 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -31,6 +31,7 @@
pre-commit = "^3.5.0"
black = "^23.10.0"
flake8 = "^6.1.0"
wandb = "^0.16.0"
plotly = "^5.22.0"



129 changes: 85 additions & 44 deletions src/frdc/train/fixmatch_module.py
@@ -53,103 +53,144 @@ def __init__(
self.n_classes = n_classes
self.unl_conf_threshold = unl_conf_threshold
self.save_hyperparameters()

# We disable PyTorch Lightning's auto backward during training.
# See why in `self.training_step` docstring
self.automatic_optimization = False

@abstractmethod
def forward(self, x):
...

@staticmethod
def loss_lbl(lbl_pred: torch.Tensor, lbl: torch.Tensor):
return F.cross_entropy(lbl_pred, lbl)

@staticmethod
def loss_unl(unl_pred: torch.Tensor, unl: torch.Tensor):
return F.cross_entropy(unl_pred, unl)
def training_step(self, batch, batch_idx):
"""A single training step for a batch

Notes:
As mentioned in __init__, we back-propagate manually for
performance reasons. When we loop through `x_unls` (unlabelled
batches), the gradient tape would otherwise accumulate unnecessarily.
We only need to back-propagate once per batch in `x_unls`,
which keeps each tape short.

Every tape is bounded by the code:

>>> opt.zero_grad()
>>> # Steps that require grad
>>> loss = ...
>>> self.manual_backward(loss)
>>> opt.step()

The losses are defined as follows:

Labelled loss:   ℓ_lbl = (1/n) Σ_{∀ lbl} CE(y_true, y_weak)

Unlabelled loss: ℓ_unl = (1/n) Σ_{∀ unl} [CE(y_strong, y_weak) if y_weak > t else 0]

Total loss: ℓ_lbl + ℓ_unl
"""
def training_step(self, batch, batch_idx):
(x_lbl, y_lbl), x_unls = batch
opt = self.optimizers()
opt.zero_grad()

self.log("train/x_lbl_mean", x_lbl.mean())
self.log("train/x_lbl_stdev", x_lbl.std())

wandb.log({"train/x_lbl": wandb_hist(y_lbl, self.n_classes)})
# Backprop for labelled data
opt.zero_grad()
loss_lbl = F.cross_entropy((y_lbl_pred := self(x_lbl)), y_lbl.long())
self.manual_backward(loss_lbl)
opt.step()

wandb.log(
{
"train/y_lbl_pred": wandb_hist(
torch.argmax(y_lbl_pred, dim=1), self.n_classes
)
}
)
# This is only for logging purposes
loss_unl = 0

# Backprop for unlabelled data
for x_weak, x_strong in x_unls:
opt.zero_grad()
self.log("train/x0_unl_mean", x_weak[0].mean())
self.log("train/x0_unl_stdev", x_weak[0].std())
with torch.no_grad():
y_pred_weak = self(x_weak)
y_pred_weak_max, y_pred_weak_max_ix = torch.max(
y_pred_weak, dim=1
)
is_confident = y_pred_weak_max >= self.unl_conf_threshold

y_pred_strong = self(x_strong[is_confident])
# If the confidence of y_weak exceeds the threshold,
# include the sample in the loss;
# otherwise, we simply mask it out
with torch.no_grad():
y_weak = self(x_weak)
y_weak_max, y_weak_max_ix = torch.max(y_weak, dim=1)
is_confident = y_weak_max >= self.unl_conf_threshold

y_strong = self(x_strong[is_confident])

# CE only on the confident (non-masked-out) samples.
# We perform `reduction="sum"` so that we still "include" the masked-out
# samples by fixing the denominator.
# E.g.
# y_weak > t = [T, F, T, F]
# Losses = [1, 2, 3, 4]
# Masked Losses = [1, 3, ]
# Incorrect CE Mean = (1 + 3) / 2
# Correct CE Mean = (1 + 3) / 4
batch_size = x_lbl.shape[0]
loss_unl_i = F.cross_entropy(
y_pred_strong,
y_pred_weak_max_ix[is_confident],
y_strong,
y_weak_max_ix[is_confident],
reduction="sum",
) / (len(x_unls) * x_lbl.shape[0])
) / (len(x_unls) * batch_size)

self.manual_backward(loss_unl_i)
opt.step()

loss_unl += loss_unl_i.detach().item()

self.log("train/x0_unl_mean", x_weak[0].mean())
self.log("train/x0_unl_stdev", x_weak[0].std())
wandb.log(
{
"train/y_unl_pred": wandb_hist(
torch.argmax(y_pred_strong, dim=1), self.n_classes
torch.argmax(y_strong, dim=1), self.n_classes
)
}
)

self.log("train/ce_loss_lbl", loss_lbl)
self.log("train/ce_loss_unl", loss_unl)
self.log("train/loss", loss_lbl + loss_unl)

# Evaluate train accuracy
with torch.no_grad():
y_pred = self(x_lbl)
acc = accuracy(
y_pred, y_lbl, task="multiclass", num_classes=y_pred.shape[1]
)
self.log("train/acc", acc, prog_bar=True)

self.log("train/x_lbl_mean", x_lbl.mean())
self.log("train/x_lbl_stdev", x_lbl.std())
wandb.log({"train/x_lbl": wandb_hist(y_lbl, self.n_classes)})
self.log("train/ce_loss_lbl", loss_lbl)
self.log("train/ce_loss_unl", loss_unl)
self.log("train/loss", loss_lbl + loss_unl)
self.log("train/acc", acc, prog_bar=True)

wandb.log(
{
"train/y_lbl_pred": wandb_hist(
torch.argmax(y_lbl_pred, dim=1), self.n_classes
)
}
)

def validation_step(self, batch, batch_idx):
# The batch outputs x_unls due to our on_before_batch_transfer
(x, y), _x_unls = batch
wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)})
y_pred = self(x)
loss = F.cross_entropy(y_pred, y.long())
acc = accuracy(
y_pred, y, task="multiclass", num_classes=self.n_classes
)

wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)})
wandb.log(
{
"val/y_lbl_pred": wandb_hist(
torch.argmax(y_pred, dim=1), self.n_classes
)
}
)
loss = F.cross_entropy(y_pred, y.long())

acc = accuracy(
y_pred, y, task="multiclass", num_classes=y_pred.shape[1]
)

self.log("val/ce_loss", loss)
self.log("val/acc", acc, prog_bar=True)
return loss
@@ -161,7 +202,7 @@ def test_step(self, batch, batch_idx):
loss = F.cross_entropy(y_pred, y.long())

acc = accuracy(
y_pred, y, task="multiclass", num_classes=y_pred.shape[1]
y_pred, y, task="multiclass", num_classes=self.n_classes
)
self.log("test/ce_loss", loss)
self.log("test/acc", acc, prog_bar=True)
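
As a side note, the `reduction="sum"` comment in `training_step` relies on the fact that summing the cross-entropy of only the confident samples and dividing by the full batch size is equivalent to assigning the masked-out samples a loss of zero and taking the mean. A quick sketch verifying this with made-up tensors (not code from this PR):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)            # 4 unlabelled samples, 3 classes
pseudo = torch.tensor([0, 2, 1, 1])   # pseudo-labels from the weak branch
is_confident = torch.tensor([True, False, True, False])

# Sum the CE over confident samples only, divide by the full batch size.
loss_a = F.cross_entropy(
    logits[is_confident], pseudo[is_confident], reduction="sum"
) / logits.shape[0]

# Equivalent: per-sample losses, zero out the unconfident ones, then mean.
per_sample = F.cross_entropy(logits, pseudo, reduction="none")
loss_b = (per_sample * is_confident).mean()

assert torch.allclose(loss_a, loss_b)
```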