diff --git a/kliff/trainer/lightning_trainer.py b/kliff/trainer/lightning_trainer.py
index ec08c83..1f80228 100644
--- a/kliff/trainer/lightning_trainer.py
+++ b/kliff/trainer/lightning_trainer.py
@@ -23,7 +23,7 @@ else:
     from torch_geometric.data.lightning import LightningDataset
 
-from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
 
 from kliff.dataset import Dataset
 from kliff.utils import get_n_configs_in_xyz
@@ -39,6 +39,9 @@
 import hashlib
 
+from .torch_trainer_utils.lightning_checkpoints import SaveModelCallback, LossTrajectoryCallback
+from pytorch_lightning.callbacks import EarlyStopping
+
 
 class LightningTrainerWrapper(pl.LightningModule):
     """
@@ -60,6 +63,8 @@ def __init__(
         n_workers=1,
         energy_weight=1.0,
         forces_weight=1.0,
+        lr_scheduler=None,
+        lr_scheduler_args=None,
     ):
         super().__init__()
@@ -79,6 +84,9 @@ def __init__(
             ema.to(device)
             self.ema = ema
 
+        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler_args = lr_scheduler_args
+
     def forward(self, batch):
         batch["coords"].requires_grad_(True)
         model_inputs = {k: batch[k] for k in self.input_args}
@@ -108,7 +116,7 @@ def training_step(self, batch, batch_idx):
             predicted_forces.squeeze(), target_forces.squeeze()
         )
         self.log(
-            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
         )
         return loss
@@ -116,7 +124,21 @@ def configure_optimizers(self):
         optimizer = getattr(torch.optim, self.optimizer_name)(
             self.model.parameters(), lr=self.lr
         )
-        return optimizer
+
+        if self.lr_scheduler:
+            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler)(
+                optimizer, **self.lr_scheduler_args
+            )
+            return {"optimizer": optimizer,
+                    "lr_scheduler":
+                        {"scheduler": self.lr_scheduler,
+                         "interval": "epoch",
+                         "frequency": 1,
+                         "monitor": "val_loss"
+                        }
+                    }
+        else:
+            return optimizer
 
     def validation_step(self, batch, batch_idx):
         torch.set_grad_enabled(True)
@@ -129,22 +151,22 @@ def validation_step(self, batch, batch_idx):
 
         predicted_energy, predicted_forces = self.forward(batch)
 
+        per_atom_force_loss = torch.sum(
+            (predicted_forces.squeeze() - target_forces.squeeze()) ** 2, dim=1
+        )
+
         loss = energy_weight * F.mse_loss(
             predicted_energy.squeeze(), target_energy.squeeze()
-        ) + forces_weight * F.mse_loss(
-            predicted_forces.squeeze(), target_forces.squeeze()
-        )
+        ) + forces_weight * torch.mean(per_atom_force_loss) / 3  # divide by 3 to get correct MSE
+
         self.log(
-            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
         )
-        return loss
+        return {"val_loss": loss, "per_atom_force_loss": per_atom_force_loss}
 
     # def test_step(self, batch, batch_idx):
     #     pass
     #
-    # def configure_optimizers(self):
-    #     pass
-    #
     # def setup_model(self):
     #     pass
     #
@@ -168,8 +190,13 @@ def __init__(self, manifest, model):
 
         super().__init__(manifest, model)
 
+        # loggers and callbacks
         self.tb_logger = self._tb_logger()
+        self.csv_logger = self._csv_logger()
         self.setup_dataloaders()
+        self.callbacks = self._get_callbacks()
+
+        # setup lightning trainer
         self.pl_trainer = self._get_pl_trainer()
 
     def setup_model(self):
@@ -180,6 +207,8 @@ def setup_model(self):
         else:
             ema_decay = None
 
+        scheduler = self.optimizer_manifest.get("lr_scheduler", {})
+
         self.pl_model = LightningTrainerWrapper(
             model=self.model,
             input_args=self.model_manifest["input_args"],
@@ -192,6 +221,8 @@
             lr=self.optimizer_manifest["learning_rate"],
             energy_weight=self.loss_manifest["weights"]["energy"],
             forces_weight=self.loss_manifest["weights"]["forces"],
+            lr_scheduler=scheduler.get("name", None),
+            lr_scheduler_args=scheduler.get("args", None),
         )
 
     def train(self):
@@ -223,16 +254,57 @@ def setup_dataloaders(self):
         logger.info("Data modules setup complete.")
 
     def _tb_logger(self):
-        return TensorBoardLogger(self.current["run_dir"], name="lightning_logs")
+        return TensorBoardLogger(f"{self.current['run_dir']}/logs", name="lightning_logs")
+
+    def _csv_logger(self):
+        return CSVLogger(f"{self.current['run_dir']}/logs", name="csv_logs")
 
     def _get_pl_trainer(self):
         return pl.Trainer(
-            logger=self.tb_logger,
+            logger=[self.tb_logger, self.csv_logger],
             max_epochs=self.optimizer_manifest["epochs"],
             accelerator="auto",
-            strategy="auto",
+            strategy="ddp",
+            callbacks=self.callbacks,
         )
 
+    def _get_callbacks(self):
+        callbacks = []
+
+        ckpt_dir = f"{self.current['run_dir']}/checkpoints"
+        ckpt_interval = self.training_manifest.get("ckpt_interval", 50)
+        save_model_callback = SaveModelCallback(ckpt_dir, ckpt_interval)
+        callbacks.append(save_model_callback)
+        logger.info("Checkpointing setup complete.")
+
+        if self.training_manifest.get("early_stopping", False):
+            patience = self.training_manifest["early_stopping"].get("patience", 10)
+            if not isinstance(patience, int):
+                raise TrainerError(
+                    f"Early stopping patience should be an integer, got {patience}"
+                )
+            delta = self.training_manifest["early_stopping"].get("delta", 1e-3)
+            early_stopping = EarlyStopping(
+                monitor="val_loss",
+                patience=patience,
+                mode="min",
+                min_delta=delta,
+            )
+            callbacks.append(early_stopping)
+            logger.info("Early stopping setup complete.")
+
+        if self.loss_manifest.get("loss_traj", False):
+            loss_traj_folder = f"{self.current['run_dir']}/loss_trajectory"
+            loss_idxs = self.dataset_sample_manifest["val_indices"]
+            ckpt_interval = self.training_manifest.get("ckpt_interval", 10)
+            loss_trajectory_callback = LossTrajectoryCallback(
+                loss_traj_folder, loss_idxs, ckpt_interval
+            )
+            callbacks.append(loss_trajectory_callback)
+            logger.info("Loss trajectory setup complete.")
+
+        return callbacks
+
     def setup_optimizer(self):
         # Not needed as Pytorch Lightning handles the optimizer
         pass
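For reference, a minimal sketch of the manifest entries consumed by the trainer changes above; the key names ("lr_scheduler"/"name"/"args", "ckpt_interval", "early_stopping"/"patience"/"delta", "loss_traj", "weights") are taken from the lookups in this diff, but the layout and the example values are illustrative assumptions, not part of the change:

# Hypothetical manifest fragments matching the .get()/[] accesses above (values are examples only).
optimizer_manifest = {
    "name": "Adam",
    "learning_rate": 1e-3,
    "epochs": 500,
    "lr_scheduler": {
        "name": "ReduceLROnPlateau",      # any class name in torch.optim.lr_scheduler
        "args": {"factor": 0.5, "patience": 20, "min_lr": 1e-6},
    },
}
training_manifest = {
    "ckpt_interval": 50,                  # epochs between "epoch_{n}.pth" snapshots
    "early_stopping": {"patience": 10, "delta": 1e-3},
}
loss_manifest = {
    "weights": {"energy": 1.0, "forces": 1.0},
    "loss_traj": True,                    # enables LossTrajectoryCallback on the validation set
}

Since configure_optimizers hard-codes monitor="val_loss" with interval="epoch", a metric-driven scheduler such as ReduceLROnPlateau steps on the epoch-level validation loss logged in validation_step.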
diff --git a/kliff/trainer/torch_trainer_utils/lightning_checkpoints.py b/kliff/trainer/torch_trainer_utils/lightning_checkpoints.py
new file mode 100644
index 0000000..f307272
--- /dev/null
+++ b/kliff/trainer/torch_trainer_utils/lightning_checkpoints.py
@@ -0,0 +1,66 @@
+import torch
+import pytorch_lightning as pl
+import os
+from kliff.dataset import Dataset
+import dill
+
+
+class SaveModelCallback(pl.Callback):
+    """
+    Callback to save the model at the end of each validation epoch. The latest model is saved in
+    ckpt_dir as "last_model.pth", the best model (lowest validation loss) as "best_model.pth",
+    and a snapshot is saved every ckpt_interval epochs as "epoch_{epoch}.pth".
+    """
+    def __init__(self, ckpt_dir, ckpt_interval=100):
+        super().__init__()
+        self.ckpt_dir = ckpt_dir
+        self.best_val_loss = float("inf")
+        self.ckpt_interval = ckpt_interval
+        os.makedirs(self.ckpt_dir, exist_ok=True)
+
+    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        # Save the last model
+        last_save_path = os.path.join(self.ckpt_dir, "last_model.pth")
+        torch.save(pl_module.state_dict(), last_save_path)
+
+        # Save the best model
+        if trainer.callback_metrics.get("val_loss") < self.best_val_loss:
+            self.best_val_loss = trainer.callback_metrics["val_loss"]
+            best_save_path = os.path.join(self.ckpt_dir, "best_model.pth")
+            torch.save(pl_module.state_dict(), best_save_path)
+
+        # Save the model every ckpt_interval epochs
+        if pl_module.current_epoch % self.ckpt_interval == 0:
+            epoch_save_path = os.path.join(self.ckpt_dir, f"epoch_{pl_module.current_epoch}.pth")
+            torch.save(pl_module.state_dict(), epoch_save_path)
+
+
+class LossTrajectoryCallback(pl.Callback):
+    """
+    Callback to save the loss trajectory of the model during validation. Every ckpt_interval
+    epochs, the per-atom force losses are appended to "loss_traj_idx.csv" in loss_traj_folder.
+    Currently, only the per-atom force loss is logged.
+    """
+    def __init__(self, loss_traj_folder, val_dataset: Dataset, ckpt_interval=10):
+        super().__init__()
+        self.loss_traj_folder = loss_traj_folder
+        self.ckpt_interval = ckpt_interval
+
+        os.makedirs(self.loss_traj_folder, exist_ok=True)
+        with open(os.path.join(self.loss_traj_folder, "loss_traj_idx.csv"), "w") as f:
+            f.write("epoch,loss\n")
+
+        dill.dump(val_dataset, open(os.path.join(self.loss_traj_folder, "val_dataset.pkl"), "wb"))
+        self.val_losses = []
+
+    def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, dataloader_idx=0):
+        if trainer.current_epoch % self.ckpt_interval == 0:
+            val_force_loss = outputs["per_atom_force_loss"].detach().cpu().numpy()
+            self.val_losses.extend(val_force_loss)
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        if trainer.current_epoch % self.ckpt_interval == 0:
+            with open(os.path.join(self.loss_traj_folder, "loss_traj_idx.csv"), "a") as f:
+                loss_str = ",".join([str(trainer.current_epoch)] + [f"{loss:.5f}" for loss in self.val_losses])
+                f.write(f"{loss_str}\n")
+            self.val_losses = []
+
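The files written by SaveModelCallback are plain torch.save dumps of pl_module.state_dict(), so a run's best or last weights can be restored outside of Lightning. A usage sketch only: pl_model stands for a LightningTrainerWrapper constructed with the same settings used for training (its constructor arguments are not repeated here), and the path mirrors the ckpt_dir layout from _get_callbacks above.

import torch

# Load weights saved by SaveModelCallback back into the wrapper module.
state_dict = torch.load("run_dir/checkpoints/best_model.pth", map_location="cpu")
pl_model.load_state_dict(state_dict)   # pl_model: LightningTrainerWrapper built as in setup_model()
pl_model.eval()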
diff --git a/kliff/transforms/configuration_transforms/graphs/generate_graph.py b/kliff/transforms/configuration_transforms/graphs/generate_graph.py
index 9496512..930308f 100644
--- a/kliff/transforms/configuration_transforms/graphs/generate_graph.py
+++ b/kliff/transforms/configuration_transforms/graphs/generate_graph.py
@@ -46,6 +46,7 @@ def __init__(self):
         self.images = None
         self.species = None
         self.z = None
+        self.cell = None
         self.contributions = None
 
     def __inc__(self, key: str, value: torch.Tensor, *args, **kwargs):
@@ -164,6 +165,7 @@ def to_py_graph(graph: graph_module.GraphData) -> PyGGraph:
     pyg_graph.images = torch.as_tensor(graph.images)
     pyg_graph.species = torch.as_tensor(graph.species)
     pyg_graph.z = torch.as_tensor(graph.z)
+    pyg_graph.cell = torch.as_tensor(graph.cell)
     pyg_graph.contributions = torch.as_tensor(graph.contributions)
     pyg_graph.num_nodes = torch.as_tensor(graph.n_nodes)
     for i in range(graph.n_layers):
diff --git a/kliff/transforms/configuration_transforms/graphs/kim_driver_graph_data.cpp b/kliff/transforms/configuration_transforms/graphs/kim_driver_graph_data.cpp
index dea9ff1..00a9646 100644
--- a/kliff/transforms/configuration_transforms/graphs/kim_driver_graph_data.cpp
+++ b/kliff/transforms/configuration_transforms/graphs/kim_driver_graph_data.cpp
@@ -41,6 +41,7 @@ struct GraphData{
   py::array_t images;  // periodic images of the atoms
   py::array_t species;  // species index of the atoms
   py::array_t z;  // atomic number of the atoms
+  py::array_t cell;  // cell of the system
   py::array_t contributions;  // contributing of the atoms to the energy
 };
 
@@ -260,6 +261,9 @@ GraphData get_complete_graph(
       = py::array_t(n_atoms + n_pad, need_neighbors_64.data());
   gs.n_nodes = n_atoms + n_pad;
+
+  gs.cell = py::array_t(cell.size(), cell.data());
+
   delete[] padded_coords;
   delete[] need_neighbors;
   for (int i = 0; i < n_graph_layers; i++) { delete[] graph_edge_indices[i]; }
@@ -282,6 +286,7 @@ PYBIND11_MODULE(graph_module, m)
       .def_readwrite("images", &GraphData::images)
      .def_readwrite("species", &GraphData::species)
       .def_readwrite("z", &GraphData::z)
+      .def_readwrite("cell", &GraphData::cell)
       .def_readwrite("contributions", &GraphData::contributions);
 
   m.def("get_complete_graph", &get_complete_graph,
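With GraphData.cell exposed through the pybind layer and copied into the PyG graph by to_py_graph, downstream models can read the simulation cell directly from the graph object. A small consumption sketch, assuming the cell is stored as 9 values per configuration and that graphs are collated with torch_geometric's default batching (both assumptions, not stated in this diff):

import torch

# Single configuration: `pyg_graph` is the object returned by to_py_graph(graph_data).
cell = pyg_graph.cell.reshape(3, 3)     # assumed row-major 3x3 cell

# Batched configurations: default PyG collation concatenates `cell` along dim 0,
# so a batch of B graphs can be viewed per configuration.
cells = batch.cell.reshape(-1, 3, 3)    # shape (B, 3, 3)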