
Commit

Added Lightning checkpoints for model save and loss traj + prelims for stress in training loss
ipcamit committed May 27, 2024
1 parent e8cb5fc commit 26edfce
Showing 4 changed files with 159 additions and 14 deletions.
100 changes: 86 additions & 14 deletions kliff/trainer/lightning_trainer.py
@@ -23,7 +23,7 @@
else:
from torch_geometric.data.lightning import LightningDataset

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger

from kliff.dataset import Dataset
from kliff.utils import get_n_configs_in_xyz
@@ -39,6 +39,9 @@

import hashlib

from .torch_trainer_utils.lightning_checkpoints import SaveModelCallback, LossTrajectoryCallback
from pytorch_lightning.callbacks import EarlyStopping


class LightningTrainerWrapper(pl.LightningModule):
"""
@@ -60,6 +63,8 @@ def __init__(
n_workers=1,
energy_weight=1.0,
forces_weight=1.0,
lr_scheduler=None,
lr_scheduler_args=None,
):

super().__init__()
@@ -79,6 +84,9 @@ def __init__(
ema.to(device)
self.ema = ema

self.lr_scheduler = lr_scheduler
self.lr_scheduler_args = lr_scheduler_args

def forward(self, batch):
batch["coords"].requires_grad_(True)
model_inputs = {k: batch[k] for k in self.input_args}
@@ -108,15 +116,29 @@ def training_step(self, batch, batch_idx):
predicted_forces.squeeze(), target_forces.squeeze()
)
self.log(
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
"train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
return loss

def configure_optimizers(self):
optimizer = getattr(torch.optim, self.optimizer_name)(
self.model.parameters(), lr=self.lr
)
return optimizer

        if self.lr_scheduler:
            scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler)(
                optimizer, **(self.lr_scheduler_args or {})  # guard against None args
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "epoch",
                    "frequency": 1,
                    "monitor": "val_loss",
                },
            }
        else:
            return optimizer
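For context, the scheduler is looked up by name in `torch.optim.lr_scheduler` and its arguments come from the optimizer manifest (see `setup_model` below). A minimal sketch of the expected manifest fragment; the `lr_scheduler` name/args keys match the code, the surrounding keys are assumptions:

```python
# Hypothetical manifest fragment: only the "lr_scheduler" name/args keys are
# read by setup_model; the optimizer-level keys shown here are assumptions.
optimizer_manifest = {
    "name": "Adam",              # resolved via getattr(torch.optim, ...)
    "learning_rate": 1e-3,
    "epochs": 100,
    "lr_scheduler": {
        "name": "ReduceLROnPlateau",             # any torch.optim.lr_scheduler class
        "args": {"factor": 0.5, "patience": 5},  # kwargs forwarded to that class
    },
}
```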

def validation_step(self, batch, batch_idx):
torch.set_grad_enabled(True)
@@ -129,22 +151,22 @@ def validation_step(self, batch, batch_idx):

predicted_energy, predicted_forces = self.forward(batch)

per_atom_force_loss = torch.sum(
(predicted_forces.squeeze() - target_forces.squeeze()) ** 2, dim=1
)

loss = energy_weight * F.mse_loss(
predicted_energy.squeeze(), target_energy.squeeze()
) + forces_weight * F.mse_loss(
predicted_forces.squeeze(), target_forces.squeeze()
)
) + forces_weight * torch.mean(per_atom_force_loss) / 3  # divide by 3 (x, y, z components) to match per-component MSE

self.log(
"val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
"val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
return loss
return {"val_loss": loss, "per_atom_force_loss": per_atom_force_loss}

# def test_step(self, batch, batch_idx):
# pass
#
# def configure_optimizers(self):
# pass
#
# def setup_model(self):
# pass
#
@@ -168,8 +190,13 @@ def __init__(self, manifest, model):

super().__init__(manifest, model)

# loggers and callbacks
self.tb_logger = self._tb_logger()
self.csv_logger = self._csv_logger()
self.setup_dataloaders()
self.callbacks = self._get_callbacks()

# setup lightning trainer
self.pl_trainer = self._get_pl_trainer()

def setup_model(self):
@@ -180,6 +207,8 @@ def setup_model(self):
else:
ema_decay = None

scheduler = self.optimizer_manifest.get("lr_scheduler", {})

self.pl_model = LightningTrainerWrapper(
model=self.model,
input_args=self.model_manifest["input_args"],
@@ -192,6 +221,8 @@ def setup_model(self):
lr=self.optimizer_manifest["learning_rate"],
energy_weight=self.loss_manifest["weights"]["energy"],
forces_weight=self.loss_manifest["weights"]["forces"],
lr_scheduler=scheduler.get("name", None),
lr_scheduler_args=scheduler.get("args", None),
)

def train(self):
@@ -223,16 +254,57 @@ def setup_dataloaders(self):
logger.info("Data modules setup complete.")

def _tb_logger(self):
return TensorBoardLogger(self.current["run_dir"], name="lightning_logs")
return TensorBoardLogger(f"{self.current['run_dir']}/logs", name="lightning_logs")

def _csv_logger(self):
return CSVLogger(f"{self.current['run_dir']}/logs", name="csv_logs")

def _get_pl_trainer(self):
return pl.Trainer(
logger=self.tb_logger,
logger=[self.tb_logger, self.csv_logger],
max_epochs=self.optimizer_manifest["epochs"],
accelerator="auto",
strategy="auto",
strategy="ddp",
callbacks=self.callbacks,
)

def _get_callbacks(self):
callbacks = []

ckpt_dir = f"{self.current['run_dir']}/checkpoints"
ckpt_interval = self.training_manifest.get("ckpt_interval", 50)
save_model_callback = SaveModelCallback(ckpt_dir, ckpt_interval)
callbacks.append(save_model_callback)
logger.info("Checkpointing setup complete.")

if self.training_manifest.get("early_stopping", False):
patience = self.training_manifest["early_stopping"].get("patience", 10)
if not isinstance(patience, int):
raise TrainerError(
f"Early stopping should be an integer, got {patience}"
)
delta = self.training_manifest["early_stopping"].get("delta", 1e-3)
early_stopping = EarlyStopping(
monitor="val_loss",
patience=patience,
mode="min",
min_delta=delta,
)
callbacks.append(early_stopping)
logger.info("Early stopping setup complete.")

if self.loss_manifest.get("loss_traj", False):
loss_traj_folder = f"{self.current['run_dir']}/loss_trajectory"
loss_idxs = self.dataset_sample_manifest["val_indices"]
ckpt_interval = self.training_manifest.get("ckpt_interval", 10)
loss_trajectory_callback = LossTrajectoryCallback(
loss_traj_folder, loss_idxs, ckpt_interval
)
callbacks.append(loss_trajectory_callback)
logger.info("Loss trajectory setup complete.")

return callbacks
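For reference, a sketch of the manifest keys this method reads; the key names follow the `get` calls in `_get_callbacks` above, everything else is illustrative:

```python
# Hypothetical manifest fragments matching the lookups in _get_callbacks.
training_manifest = {
    "ckpt_interval": 50,                                # epochs between model snapshots
    "early_stopping": {"patience": 10, "delta": 1e-3},  # omit or set False to disable
}
loss_manifest = {
    "loss_traj": True,  # record per-atom force losses on the validation set
}
```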

def setup_optimizer(self):
# Not needed as Pytorch Lightning handles the optimizer
pass
66 changes: 66 additions & 0 deletions kliff/trainer/torch_trainer_utils/lightning_checkpoints.py
@@ -0,0 +1,66 @@
import torch
import pytorch_lightning as pl
import os
from kliff.dataset import Dataset
import dill


class SaveModelCallback(pl.Callback):
"""
Callback to save the model at the end of each epoch. The model is saved in the ckpt_dir with the name
"last_model.pth". The best model is saved with the name "best_model.pth". The model is saved every
ckpt_interval epochs with the name "epoch_{epoch}.pth".
"""
def __init__(self, ckpt_dir, ckpt_interval=100):
super().__init__()
self.ckpt_dir = ckpt_dir
self.best_val_loss = float("inf")
self.ckpt_interval = ckpt_interval
os.makedirs(self.ckpt_dir, exist_ok=True)

def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
# Save the last model
last_save_path = os.path.join(self.ckpt_dir, "last_model.pth")
torch.save(pl_module.state_dict(), last_save_path)

        # Save the best model seen so far; guard against val_loss being absent
        val_loss = trainer.callback_metrics.get("val_loss")
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            best_save_path = os.path.join(self.ckpt_dir, "best_model.pth")
            torch.save(pl_module.state_dict(), best_save_path)

# Save the model every ckpt_interval epochs
if pl_module.current_epoch % self.ckpt_interval == 0:
epoch_save_path = os.path.join(self.ckpt_dir, f"epoch_{pl_module.current_epoch}.pth")
torch.save(pl_module.state_dict(), epoch_save_path)
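A minimal usage sketch for the callback (the directory name is illustrative):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[SaveModelCallback("run/checkpoints", ckpt_interval=50)],
)
# Weights written by the callback are plain state dicts and can be restored with:
# pl_module.load_state_dict(torch.load("run/checkpoints/best_model.pth"))
```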


class LossTrajectoryCallback(pl.Callback):
"""
Callback to save the loss trajectory of the model during validation. The loss trajectory is saved in the
loss_traj_file. The loss trajectory is saved every ckpt_interval epochs. Currently, it only logs per atom force loss.
"""
def __init__(self, loss_traj_folder, val_dataset: Dataset, ckpt_interval=10):
super().__init__()
self.loss_traj_folder = loss_traj_folder
self.ckpt_interval = ckpt_interval

os.makedirs(self.loss_traj_folder, exist_ok=True)
with open(os.path.join(self.loss_traj_folder, "loss_traj_idx.csv"), "w") as f:
f.write("epoch,loss\n")

dill.dump(val_dataset, open(os.path.join(self.loss_traj_folder, "val_dataset.pkl"), "wb"))
self.val_losses = []

def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, dataloader_idx=0):
if trainer.current_epoch % self.ckpt_interval == 0:
val_force_loss = outputs["per_atom_force_loss"].detach().cpu().numpy()
self.val_losses.extend(val_force_loss)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.current_epoch % self.ckpt_interval == 0:
with open(os.path.join(self.loss_traj_folder, "loss_traj_idx.csv"), "a") as f:
loss_str = ",".join([str(trainer.current_epoch)] + [f"{loss:.5f}" for loss in self.val_losses])
f.write(f"{loss_str}\n")
self.val_losses = []
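Since each appended row holds the epoch followed by one loss per validation atom, rows have variable width; a sketch of reading the file back (the path is illustrative):

```python
# Parse loss_traj_idx.csv by hand: the column count varies per row.
with open("run/loss_trajectory/loss_traj_idx.csv") as f:
    next(f)  # skip the "epoch,loss" header
    for line in f:
        epoch, *losses = line.rstrip("\n").split(",")
        print(epoch, [float(x) for x in losses])
```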

@@ -46,6 +46,7 @@ def __init__(self):
self.images = None
self.species = None
self.z = None
self.cell = None
self.contributions = None

def __inc__(self, key: str, value: torch.Tensor, *args, **kwargs):
@@ -164,6 +165,7 @@ def to_py_graph(graph: graph_module.GraphData) -> PyGGraph:
pyg_graph.images = torch.as_tensor(graph.images)
pyg_graph.species = torch.as_tensor(graph.species)
pyg_graph.z = torch.as_tensor(graph.z)
pyg_graph.cell = torch.as_tensor(graph.cell)
pyg_graph.contributions = torch.as_tensor(graph.contributions)
pyg_graph.num_nodes = torch.as_tensor(graph.n_nodes)
for i in range(graph.n_layers):
@@ -41,6 +41,7 @@ struct GraphData{
py::array_t<int64_t> images; // periodic images of the atoms
py::array_t<int64_t> species; // species index of the atoms
py::array_t<int64_t> z; // atomic number of the atoms
py::array_t<double> cell; // cell of the system
py::array_t<int64_t> contributions; // contributing of the atoms to the energy
};

@@ -260,6 +261,9 @@ GraphData get_complete_graph(
= py::array_t<int64_t>(n_atoms + n_pad, need_neighbors_64.data());

gs.n_nodes = n_atoms + n_pad;

gs.cell = py::array_t<double>(cell.size(), cell.data());

delete[] padded_coords;
delete[] need_neighbors;
for (int i = 0; i < n_graph_layers; i++) { delete[] graph_edge_indices[i]; }
@@ -282,6 +286,7 @@ PYBIND11_MODULE(graph_module, m)
.def_readwrite("images", &GraphData::images)
.def_readwrite("species", &GraphData::species)
.def_readwrite("z", &GraphData::z)
.def_readwrite("cell", &GraphData::cell)
.def_readwrite("contributions", &GraphData::contributions);
m.def("get_complete_graph",
&get_complete_graph,
