Skip to content

Commit

Permalink
refactor: update imports to use lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Nov 11, 2024
1 parent 814d181 commit 4cae29b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 16 deletions.
2 changes: 1 addition & 1 deletion exp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import matplotlib
import matplotlib.pyplot as plt
import torchvision.utils
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.loggers import WandbLogger

from rtpt import RTPT

Expand Down
22 changes: 8 additions & 14 deletions main_pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
install()

import hydra
import pytorch_lightning as pl
import lightning.pytorch as pl
import torch.utils.data
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import StochasticWeightAveraging, RichProgressBar
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.model_summary import (
ModelSummary,
)
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import StochasticWeightAveraging, RichProgressBar, ModelSummary
from lightning.pytorch.loggers import WandbLogger

from exp_utils import (
load_from_checkpoint,
Expand Down Expand Up @@ -128,16 +125,13 @@ def main(cfg: DictConfig):
logger.info("Initializing leaf distributions from data statistics")
init_einet_stats(model.spn, train_loader)

# Store number of model parameters
summary = ModelSummary(model, max_depth=-1)
logger.info("Model:")
logger.info(model)
logger.info("Summary:")
logger.info(summary)

# Setup callbacks
callbacks = []

# Store number of model parameters
summary = ModelSummary(max_depth=-1)
callbacks.append(summary)

# Add StochasticWeightAveraging callback
if cfg.swa:
swa_callback = StochasticWeightAveraging()
Expand Down
2 changes: 1 addition & 1 deletion models_pl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
from typing import Tuple

import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
import torch.nn.parallel
import torch.utils.data
Expand Down

0 comments on commit 4cae29b

Please sign in to comment.