
Commit

add wandb log
Sentz98 committed Jun 28, 2024
1 parent 8ba8912 commit dba6286
Showing 2 changed files with 28 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
@@ -133,6 +133,7 @@ results/
ckp/
checkpoints/
*.swp
wandb/

Dockerfile
build_dgx.sh
28 changes: 27 additions & 1 deletion micromind/core.py
@@ -25,11 +25,14 @@

# This is used ONLY if you are not using argparse to get the hparams
default_cfg = {
"project_name": "micromind",
"output_folder": "results",
"experiment_name": "micromind_exp",
"opt": "adam", # this is ignored if you are overriding the configure_optimizers
"lr": 0.001, # this is ignored if you are overriding the configure_optimizers
"debug": False,
"log_wandb": False,
"wandb_resume": 'auto' # Resume run if prev crashed, otherwise new run. ["allow", "must", "never", "auto" or None]
}


@@ -381,14 +384,26 @@ def compute_macs(self, input_shape: Union[List, Tuple]):

def on_train_start(self):
"""Initializes the optimizer, modules and puts the networks on the right
devices. Optionally loads checkpoint if already present.
devices. Optionally loads checkpoint if already present. It also starts the wandb
logger if selected.
This function gets executed at the beginning of every training.
"""

# pass debug status to checkpointer
self.checkpointer.debug = self.hparams.debug

if self.hparams.log_wandb:
import wandb

self.wlog = wandb.init(
project=self.hparams.project_name,
name=self.hparams.experiment_name,
resume=self.hparams.wandb_resume,
id=self.hparams.experiment_name,
config=self.hparams
)

init_opt = self.configure_optimizers()
if isinstance(init_opt, list) or isinstance(init_opt, tuple):
self.opt, self.lr_sched = init_opt
@@ -449,6 +464,8 @@ def init_devices(self):

def on_train_end(self):
"""Runs at the end of each training. Cleans up before exiting."""
if self.hparams.log_wandb:
self.wlog.finish()
pass

def eval(self):
@@ -531,6 +548,9 @@ def train(
# ok for cos_lr
self.lr_sched.step()

if self.hparams.log_wandb:
self.wlog.log({"lr": self.lr_sched.get_last_lr()[0]})  # get_last_lr() returns a list; log the first group's lr as a scalar

for m in self.metrics:
if (
self.current_epoch + 1
@@ -560,6 +580,9 @@ def train(

train_metrics.update({"train_loss": loss_epoch / (idx + 1)})

if self.hparams.log_wandb:  # log train metrics to wandb
self.wlog.log(train_metrics)

if "val" in datasets:
val_metrics = self.validate()
if (
@@ -574,6 +597,9 @@
else:
val_metrics = train_metrics.update({"val_loss": loss_epoch / (idx + 1)})

if self.hparams.log_wandb:  # log val metrics to wandb
self.wlog.log(val_metrics)

if e >= 1 and self.debug:
break

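For reference, the logging flow added to core.py reduces to the standalone sketch below. The hparams values are illustrative assumptions; only the wandb.init, log, and finish calls mirror what on_train_start, train, and on_train_end now do.

import wandb

# Illustrative hparams mirroring default_cfg; these values are assumptions.
hparams = {
    "project_name": "micromind",
    "experiment_name": "micromind_exp",
    "log_wandb": True,
    "wandb_resume": "auto",
}

if hparams["log_wandb"]:
    # The experiment name doubles as the run id, so a previously crashed run
    # with the same name is resumed instead of a new one being created.
    run = wandb.init(
        project=hparams["project_name"],
        name=hparams["experiment_name"],
        resume=hparams["wandb_resume"],
        id=hparams["experiment_name"],
        config=hparams,
    )
    run.log({"train_loss": 0.5, "lr": 0.001})  # logged once per epoch during training
    run.finish()  # mirrors on_train_end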
