Merge pull request #45 from FR-DC/frml-107
FRML-107 Fix visual issue with W&B using too many bins for Y Prediction histograms
Eve-ning authored Jan 8, 2024
2 parents fb93890 + c8b050a · commit c2c48b3
Showing 3 changed files with 23 additions and 12 deletions.
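
For context: `wandb.Histogram` defaults to 64 bins, so a histogram of integer class labels gets sliced into fractional-width, mostly empty bins; capping `num_bins` at the class count yields exactly one bin per label, which is the fix this PR applies. Below is a minimal before/after sketch, assuming a hypothetical 10-class setup — the project name, class count, and fake labels are illustrative, not taken from the repo:

```python
import torch
import wandb

n_classes = 10  # hypothetical class count, for illustration only
y_pred = torch.randint(0, n_classes, (256,))  # fake predicted class labels

run = wandb.init(project="frdc", mode="offline")  # offline: no login needed
run.log(
    {
        # Default binning: up to 64 bins for only 10 distinct values.
        "Y Pred (default binning)": wandb.Histogram(y_pred.tolist()),
        # Capped binning: one bin per class, as in this PR.
        "Y Pred (one bin per class)": wandb.Histogram(
            y_pred.tolist(), num_bins=n_classes
        ),
    }
)
run.finish()
```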
2 changes: 0 additions & 2 deletions src/frdc/conf.py
@@ -88,5 +88,3 @@
         f"LABEL_STUDIO_CLIENT will be None."
     )
     LABEL_STUDIO_CLIENT = None
-
-
19 changes: 16 additions & 3 deletions src/frdc/train/mixmatch_module.py
@@ -154,7 +154,11 @@ def progress(self):
     def training_step(self, batch, batch_idx):
         (x_lbl, y_lbl), x_unls = batch
         self.lbl_logger(
-            self.logger.experiment, "Input Y Label", y_lbl, flush_every=10
+            self.logger.experiment,
+            "Input Y Label",
+            y_lbl,
+            flush_every=10,
+            num_bins=self.n_classes,
         )
 
         y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)
@@ -186,13 +190,15 @@ def training_step(self, batch, batch_idx):
             "Labelled Y Pred",
             torch.argmax(y_mix_lbl_pred, dim=1),
             flush_every=10,
+            num_bins=self.n_classes,
         )
         loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
         self.lbl_logger(
             self.logger.experiment,
             "Unlabelled Y Pred",
             torch.argmax(y_mix_unl_pred, dim=1),
             flush_every=10,
+            num_bins=self.n_classes,
         )
         loss_unl_scale = self.loss_unl_scaler(progress=self.progress)
 
@@ -218,14 +224,19 @@ def on_after_backward(self) -> None:
     def validation_step(self, batch, batch_idx):
         x, y = batch
         self.lbl_logger(
-            self.logger.experiment, "Val Input Y Label", y, flush_every=1
+            self.logger.experiment,
+            "Val Input Y Label",
+            y,
+            flush_every=1,
+            num_bins=self.n_classes,
         )
         y_pred = self.ema_model(x)
         self.lbl_logger(
             self.logger.experiment,
             "Val Pred Y Label",
             torch.argmax(y_pred, dim=1),
             flush_every=1,
+            num_bins=self.n_classes,
         )
         loss = F.cross_entropy(y_pred, y.long())
 
@@ -334,6 +345,7 @@ def __call__(
         logger: wandb.sdk.wandb_run.Run,
         key: str,
         value: torch.Tensor,
+        num_bins: int,
         flush_every: int = 10,
     ):
         """Log the labels to WandB
@@ -354,7 +366,8 @@
         logger.log(
             {
                 key: wandb.Histogram(
-                    torch.flatten(value).detach().cpu().tolist()
+                    torch.flatten(value).detach().cpu().tolist(),
+                    num_bins=num_bins,
                 )
             }
         )
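
With the new signature, every call site must supply `num_bins`. For readers outside the repo, here is a condensed sketch of the pattern behind `lbl_logger` — the per-key buffering and flush-every-N-calls behaviour is inferred from the `flush_every` argument, not copied from the module:

```python
import torch
import wandb


class LabelLogger:
    """Logs a histogram of label tensors every `flush_every` calls per key."""

    def __init__(self):
        self.values: dict[str, list] = {}
        self.calls: dict[str, int] = {}

    def __call__(
        self,
        logger: "wandb.sdk.wandb_run.Run",
        key: str,
        value: torch.Tensor,
        num_bins: int,
        flush_every: int = 10,
    ):
        # Accumulate flattened labels until it is time to flush.
        self.values.setdefault(key, []).extend(
            torch.flatten(value).detach().cpu().tolist()
        )
        self.calls[key] = self.calls.get(key, 0) + 1
        if self.calls[key] % flush_every == 0:
            # One bin per class instead of wandb's 64-bin default.
            logger.log(
                {key: wandb.Histogram(self.values[key], num_bins=num_bins)}
            )
            self.values[key] = []
```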
14 changes: 7 additions & 7 deletions tests/model_tests/chestnut_dec_may/train.py
@@ -3,6 +3,7 @@
 This test is done by training a model on the 20201218 dataset, then testing on
 the 20210510 dataset.
 """
+import os
 
 # Uncomment this to run the W&B monitoring locally
 # import os
@@ -41,9 +42,6 @@ def main(
     val_iters=15,
     lr=1e-3,
 ):
-    run = wandb.init()
-    logger = WandbLogger(name="chestnut_dec_may", project="frdc")
-
     # Prepare the dataset
     train_lab_ds = ds.chestnut_20201218(transform=train_preprocess)
     train_unl_ds = ds.chestnut_20201218.unlabelled(
@@ -87,7 +85,9 @@ def main(
                 monitor="val_loss", mode="min", save_top_k=1
             ),
         ],
-        logger=logger,
+        logger=(
+            logger := WandbLogger(name="chestnut_dec_may", project="frdc")
+        ),
     )
 
     m = InceptionV3MixMatchModule(
@@ -103,7 +103,7 @@
     with open(Path(__file__).parent / "report.md", "w") as f:
         f.write(
             f"# Chestnut Nature Park (Dec 2020 vs May 2021)\n"
-            f"- Results: [WandB Report]({run.get_url()})"
+            f"- Results: [WandB Report]({wandb.run.get_url()})"
         )
 
     y_true, y_pred = predict(
@@ -133,8 +133,8 @@ def main(
 VAL_ITERS = 15
 LR = 1e-3
 
-assert wandb.run is None
-wandb.setup(wandb.Settings(program=__name__, program_relpath=__name__))
+wandb.login(key=os.environ["WANDB_API_KEY"])
+
 main(
     batch_size=BATCH_SIZE,
     epochs=EPOCHS,
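
The train-script change swaps an eager `wandb.init()` handle for a `WandbLogger` bound inline with the walrus operator, then recovers the run URL from the global `wandb.run`. A sketch of that pattern under assumed imports — the `pytorch_lightning` import path and trainer settings are assumptions, not taken from the diff:

```python
import os

import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

wandb.login(key=os.environ["WANDB_API_KEY"])  # key must be set in the env

trainer = Trainer(
    max_epochs=1,
    # The walrus operator binds the logger to a local name while passing it
    # in, replacing the separate `logger = WandbLogger(...)` statement.
    logger=(logger := WandbLogger(name="chestnut_dec_may", project="frdc")),
)

_ = logger.experiment  # WandbLogger initialises the underlying run lazily here
print(wandb.run.get_url())  # the URL that train.py writes into report.md
```

One design consequence: because `WandbLogger` defers `wandb.init()` until first use, the old `assert wandb.run is None` guard and explicit `wandb.setup(...)` become unnecessary.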
