diff --git a/exp_utils.py b/exp_utils.py
index e4f4403..7bb3432 100644
--- a/exp_utils.py
+++ b/exp_utils.py
@@ -9,7 +9,7 @@
 import matplotlib
 import matplotlib.pyplot as plt
 import torchvision.utils
-from pytorch_lightning.loggers import WandbLogger
+from lightning.pytorch.loggers import WandbLogger
 from rtpt import RTPT
 
 
diff --git a/main_pl.py b/main_pl.py
index abc709c..b6b687f 100644
--- a/main_pl.py
+++ b/main_pl.py
@@ -11,14 +11,11 @@
 install()
 
 import hydra
-import pytorch_lightning as pl
+import lightning.pytorch as pl
 import torch.utils.data
-from pytorch_lightning import seed_everything
-from pytorch_lightning.callbacks import StochasticWeightAveraging, RichProgressBar
-from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.utilities.model_summary import (
-    ModelSummary,
-)
+from lightning.pytorch import seed_everything
+from lightning.pytorch.callbacks import StochasticWeightAveraging, RichProgressBar, ModelSummary
+from lightning.pytorch.loggers import WandbLogger
 
 from exp_utils import (
     load_from_checkpoint,
@@ -128,16 +125,13 @@ def main(cfg: DictConfig):
         logger.info("Initializing leaf distributions from data statistics")
         init_einet_stats(model.spn, train_loader)
 
-    # Store number of model parameters
-    summary = ModelSummary(model, max_depth=-1)
-    logger.info("Model:")
-    logger.info(model)
-    logger.info("Summary:")
-    logger.info(summary)
-
     # Setup callbacks
     callbacks = []
 
+    # Store number of model parameters
+    summary = ModelSummary(max_depth=-1)
+    callbacks.append(summary)
+
     # Add StochasticWeightAveraging callback
     if cfg.swa:
         swa_callback = StochasticWeightAveraging()
diff --git a/models_pl.py b/models_pl.py
index 8b58c86..57fcfeb 100644
--- a/models_pl.py
+++ b/models_pl.py
@@ -1,7 +1,7 @@
 from abc import ABC
 from typing import Tuple
 
-import pytorch_lightning as pl
+import lightning.pytorch as pl
 import torch
 import torch.nn.parallel
 import torch.utils.data
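
For reference, a minimal sketch of the post-migration usage this diff converges on: imports come from lightning.pytorch (the Lightning 2.x namespace) instead of pytorch_lightning, and ModelSummary is the Trainer callback from lightning.pytorch.callbacks rather than the utility class that took the model as a constructor argument. The TinyModule and random dataset below are hypothetical stand-ins, not code from this repository:

    import lightning.pytorch as pl
    import torch
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import ModelSummary
    from torch.utils.data import DataLoader, TensorDataset


    class TinyModule(pl.LightningModule):
        """Hypothetical stand-in for this repo's LightningModule."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    if __name__ == "__main__":
        seed_everything(0)
        data = DataLoader(
            TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16
        )
        # ModelSummary is registered as a callback; max_depth=-1 expands the
        # full module hierarchy, matching the setting used in main_pl.py.
        trainer = pl.Trainer(
            max_epochs=1, callbacks=[ModelSummary(max_depth=-1)], logger=False
        )
        trainer.fit(TinyModule(), data)

Registered as a callback, ModelSummary prints the summary automatically when fit starts, which is why the manual logger.info(summary) calls in main_pl.py are dropped rather than moved.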