Skip to content

Commit

Permalink
Fix uninitialized unl for supervised learning
Browse files — browse the repository at this point in the history
  • Loading branch information
Eve-ning committed Feb 21, 2024
1 parent 33b9108 commit b74c406
Showing 1 changed file with 2 additions and 2 deletions.
src/frdc/train/mixmatch_module.py — 2 additions, 2 deletions
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ def training_step(self, batch, batch_idx):

self.log("train/x_lbl_mean", x_lbl.mean())
self.log("train/x_lbl_stdev", x_lbl.std())
self.log("train/x0_unl_mean", x_unls[0].mean())
self.log("train/x0_unl_stdev", x_unls[0].std())

wandb.log({"train/x_lbl": self.wandb_hist(y_lbl, self.n_classes)})
y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)
Expand All @@ -163,6 +161,8 @@ def training_step(self, batch, batch_idx):
# Otherwise, we are just using supervised learning.
if x_unls:
# This route implies that we are using SSL
self.log("train/x0_unl_mean", x_unls[0].mean())
self.log("train/x0_unl_stdev", x_unls[0].std())
with torch.no_grad():
y_unl = self.guess_labels(x_unls=x_unls)
y_unl = self.sharpen(y_unl, self.sharpen_temp)
Expand Down

0 comments on commit b74c406

Please sign in to comment.