Skip to content

Commit

Permalink
fix_flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
Sentz98 committed Jun 28, 2024
1 parent dba6286 commit 95579d0
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions micromind/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"lr": 0.001, # this is ignored if you are overriding the configure_optimizers
"debug": False,
"log_wandb": False,
"wandb_resume": 'auto' # Resume run if prev crashed, otherwise new run. ["allow", "must", "never", "auto" or None]
"wandb_resume": "auto", # ["allow", "must", "never", "auto" or None]
}


Expand Down Expand Up @@ -384,7 +384,7 @@ def compute_macs(self, input_shape: Union[List, Tuple]):

def on_train_start(self):
"""Initializes the optimizer, modules and puts the networks on the right
devices. Optionally loads checkpoint if already present. It also start wandb
devices. Optionally loads checkpoint if already present. It also start wandb
logger if selected.
This function gets executed at the beginning of every training.
Expand All @@ -397,11 +397,11 @@ def on_train_start(self):
import wandb

self.wlog = wandb.init(
project=self.hparams.project_name,
project=self.hparams.project_name,
name=self.hparams.experiment_name,
resume=self.hparams.wandb_resume,
id=self.hparams.experiment_name,
config=self.hparams
config=self.hparams,
)

init_opt = self.configure_optimizers()
Expand Down Expand Up @@ -580,8 +580,8 @@ def train(

train_metrics.update({"train_loss": loss_epoch / (idx + 1)})

if self.hparams.log_wandb: # wandb log train loss
self.wlog.log(train_metrics)
if self.hparams.log_wandb: # wandb log train loss
self.wlog.log(train_metrics)

if "val" in datasets:
val_metrics = self.validate()
Expand All @@ -597,7 +597,7 @@ def train(
else:
val_metrics = train_metrics.update({"val_loss": loss_epoch / (idx + 1)})

if self.hparams.log_wandb: # wandb log val loss
if self.hparams.log_wandb: # wandb log val loss
self.wlog.log(val_metrics)

if e >= 1 and self.debug:
Expand Down

0 comments on commit 95579d0

Please sign in to comment.