Skip to content

Commit

Permalink
Modified dataset weights + tests to reflect
Browse files Browse the repository at this point in the history
  • Loading branch information
ipcamit committed Jun 3, 2024
1 parent 26edfce commit 289a2a9
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 43 deletions.
18 changes: 12 additions & 6 deletions kliff/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def add_from_colabfit(
)
else:
configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, None)
self.add_weights(weight)
Dataset.add_weights(configs, weight)
self.configs.extend(configs)

@classmethod
Expand Down Expand Up @@ -800,7 +800,7 @@ def add_from_path(
configs = self._read_from_path(path, weight, file_format)
else:
configs = self._read_from_path(path, None, file_format)
self.add_weights(weight)
Dataset.add_weights(configs, weight)
self.configs.extend(configs)

@classmethod
Expand Down Expand Up @@ -1008,7 +1008,7 @@ def add_from_ase(
configs = self._read_from_ase(
path, ase_atoms_list, None, energy_key, forces_key, slices, file_format
)
self.add_weights(weight)
Dataset.add_weights(configs, weight)
self.configs.extend(configs)

def get_configs(self) -> List[Configuration]:
Expand Down Expand Up @@ -1064,7 +1064,8 @@ def save_weights(self, path: Union[Path, str]):
+ f"{config.weight.stress_weight}\n"
)

def add_weights(self, path: Union[Path, str]):
@staticmethod
def add_weights(configurations: List[Configuration], path: Union[Path, str]):
"""
Load weights from a text file. The text file should contain 1 to 4 columns,
whitespace seperated, formatted as,
Expand All @@ -1086,6 +1087,7 @@ def add_weights(self, path: Union[Path, str]):
```
Args:
configurations: List of configurations to add weights to.
path: Path to the configuration file
"""
Expand All @@ -1103,7 +1105,7 @@ def add_weights(self, path: Union[Path, str]):
"there needs to be at least 1 col, and at most 4"
)

if not (weights_data.size == 1 or weights_data.size == len(self)):
if not (weights_data.size == 1 or weights_data.size == len(configurations)):
raise DatasetError(
"Weights file contains improper number of rows,"
"there can be either 1 row (all weights same), "
Expand All @@ -1118,8 +1120,12 @@ def add_weights(self, path: Union[Path, str]):
for fields in missing_cols:
weights[fields] = np.zeros_like(weights["config"])

# if only one row, set same weight for all
if weights_data.size == 1:
weights = {k: np.full(len(configurations), v) for k, v in weights.items()}

# set weights
for i, config in enumerate(self.configs):
for i, config in enumerate(configurations):
config.weight = Weight(
config_weight=weights["config"][i],
energy_weight=weights["energy"][i],
Expand Down
3 changes: 3 additions & 0 deletions kliff/dataset/weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def stress_weight(self):
def stress_weight(self, value):
self._stress_weight = value

def __repr__(self):
return f"Weights: config={self.config_weight}, energy={self.energy_weight}, forces={self.forces_weight}, stress={self.stress_weight}"

def _check_compute_flag(self, config):
"""
Check whether compute flag correctly set when the corresponding weight in
Expand Down
37 changes: 23 additions & 14 deletions kliff/trainer/lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
else:
from torch_geometric.data.lightning import LightningDataset

from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

from kliff.dataset import Dataset
from kliff.utils import get_n_configs_in_xyz
Expand All @@ -39,9 +39,13 @@

import hashlib

from .torch_trainer_utils.lightning_checkpoints import SaveModelCallback, LossTrajectoryCallback
from pytorch_lightning.callbacks import EarlyStopping

from .torch_trainer_utils.lightning_checkpoints import (
LossTrajectoryCallback,
SaveModelCallback,
)


class LightningTrainerWrapper(pl.LightningModule):
"""
Expand Down Expand Up @@ -129,14 +133,15 @@ def configure_optimizers(self):
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler)(
optimizer, **self.lr_scheduler_args
)
return {"optimizer": optimizer,
"lr_scheduler":
{"scheduler": self.lr_scheduler,
"interval": "epoch",
"frequency": 1,
"monitor": "val_loss"
}
}
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": self.lr_scheduler,
"interval": "epoch",
"frequency": 1,
"monitor": "val_loss",
},
}
else:
return optimizer

Expand All @@ -155,9 +160,11 @@ def validation_step(self, batch, batch_idx):
(predicted_forces.squeeze() - target_forces.squeeze()) ** 2, dim=1
)

loss = energy_weight * F.mse_loss(
predicted_energy.squeeze(), target_energy.squeeze()
) + forces_weight * torch.mean(per_atom_force_loss)/3 # divide by 3 to get correct MSE
loss = (
energy_weight
* F.mse_loss(predicted_energy.squeeze(), target_energy.squeeze())
+ forces_weight * torch.mean(per_atom_force_loss) / 3
) # divide by 3 to get correct MSE

self.log(
"val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
Expand Down Expand Up @@ -254,7 +261,9 @@ def setup_dataloaders(self):
logger.info("Data modules setup complete.")

def _tb_logger(self):
return TensorBoardLogger(f"{self.current['run_dir']}/logs", name="lightning_logs")
return TensorBoardLogger(
f"{self.current['run_dir']}/logs", name="lightning_logs"
)

def _csv_logger(self):
return CSVLogger(f"{self.current['run_dir']}/logs", name="csv_logs")
Expand Down
43 changes: 33 additions & 10 deletions kliff/trainer/torch_trainer_utils/lightning_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
import pytorch_lightning as pl
import os
from kliff.dataset import Dataset

import dill
import pytorch_lightning as pl
import torch

from kliff.dataset import Dataset


class SaveModelCallback(pl.Callback):
Expand All @@ -11,14 +13,17 @@ class SaveModelCallback(pl.Callback):
"last_model.pth". The best model is saved with the name "best_model.pth". The model is saved every
ckpt_interval epochs with the name "epoch_{epoch}.pth".
"""

def __init__(self, ckpt_dir, ckpt_interval=100):
super().__init__()
self.ckpt_dir = ckpt_dir
self.best_val_loss = float("inf")
self.ckpt_interval = ckpt_interval
os.makedirs(self.ckpt_dir, exist_ok=True)

def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
def on_validation_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
):
# Save the last model
last_save_path = os.path.join(self.ckpt_dir, "last_model.pth")
torch.save(pl_module.state_dict(), last_save_path)
Expand All @@ -31,7 +36,9 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningMo

# Save the model every ckpt_interval epochs
if pl_module.current_epoch % self.ckpt_interval == 0:
epoch_save_path = os.path.join(self.ckpt_dir, f"epoch_{pl_module.current_epoch}.pth")
epoch_save_path = os.path.join(
self.ckpt_dir, f"epoch_{pl_module.current_epoch}.pth"
)
torch.save(pl_module.state_dict(), epoch_save_path)


Expand All @@ -40,6 +47,7 @@ class LossTrajectoryCallback(pl.Callback):
Callback to save the loss trajectory of the model during validation. The loss trajectory is saved in the
loss_traj_file. The loss trajectory is saved every ckpt_interval epochs. Currently, it only logs per atom force loss.
"""

def __init__(self, loss_traj_folder, val_dataset: Dataset, ckpt_interval=10):
super().__init__()
self.loss_traj_folder = loss_traj_folder
Expand All @@ -49,18 +57,33 @@ def __init__(self, loss_traj_folder, val_dataset: Dataset, ckpt_interval=10):
with open(os.path.join(self.loss_traj_folder, "loss_traj_idx.csv"), "w") as f:
f.write("epoch,loss\n")

dill.dump(val_dataset, open(os.path.join(self.loss_traj_folder, "val_dataset.pkl"), "wb"))
dill.dump(
val_dataset,
open(os.path.join(self.loss_traj_folder, "val_dataset.pkl"), "wb"),
)
self.val_losses = []

def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, dataloader_idx=0):
def on_validation_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx=0,
):
if trainer.current_epoch % self.ckpt_interval == 0:
val_force_loss = outputs["per_atom_force_loss"].detach().cpu().numpy()
self.val_losses.extend(val_force_loss)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.current_epoch % self.ckpt_interval == 0:
with open(os.path.join(self.loss_traj_folder, "loss_traj_idx.csv"), "a") as f:
loss_str = ",".join([str(trainer.current_epoch)] + [f"{loss:.5f}" for loss in self.val_losses])
with open(
os.path.join(self.loss_traj_folder, "loss_traj_idx.csv"), "a"
) as f:
loss_str = ",".join(
[str(trainer.current_epoch)]
+ [f"{loss:.5f}" for loss in self.val_losses]
)
f.write(f"{loss_str}\n")
self.val_losses = []

116 changes: 107 additions & 9 deletions tests/dataset/test_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from kliff.dataset import Dataset
from kliff.dataset.dataset import DatasetError
from kliff.dataset.weight import MagnitudeInverseWeight, Weight

np.random.seed(2022)
Expand Down Expand Up @@ -88,19 +89,116 @@ def _compute_magnitude_inverse_weight(c1, c2, norm):
return 1 / sigma


# tests for loading weights from a file
def test_weight_from_file():
"""Load 4 weights from a file"""
xyz_file = Path(__file__).parents[1].joinpath("test_data/configs/Si_4.xyz")
weight_file = Path(__file__).parents[1].joinpath("test_data/weights/weights_4.dat")
ds = Dataset.from_ase(
Path(__file__).parents[1].joinpath("test_data/configs/Si_4.xyz"),
xyz_file,
energy_key="Energy",
forces_key="force",
weight=Path(__file__).parents[1].joinpath("test_data/weights/Si_4_weights.dat"),
weight=weight_file,
)
configs = ds.get_configs()
assert len(configs) == 4
assert configs[0].weight.config_weight == 1.0
assert configs[0].weight.energy_weight == 0.5
assert configs[0].weight.stress_weight == 1.0
assert configs[3].weight.forces_weight == 0.5
assert configs[3].weight.config_weight == 1.0
assert configs[3].weight.energy_weight == 0.5
assert configs[3].weight.stress_weight == 4.0
weights = np.genfromtxt(weight_file, names=True)

config_weights = weights["Config"]
energy_weights = weights["Energy"]
forces_weights = weights["Forces"]
stress_weights = weights["Stress"]

assert configs[0].weight.config_weight == config_weights[0]
assert configs[0].weight.energy_weight == energy_weights[0]
assert configs[0].weight.forces_weight == forces_weights[0]
assert configs[0].weight.stress_weight == stress_weights[0]
assert configs[3].weight.config_weight == config_weights[3]
assert configs[3].weight.energy_weight == energy_weights[3]
assert configs[3].weight.forces_weight == forces_weights[3]
assert configs[3].weight.stress_weight == stress_weights[3]


def test_single_weight_from_file():
"""Load a single weight from a file"""
xyz_file = Path(__file__).parents[1].joinpath("test_data/configs/Si_4.xyz")
weight_file = Path(__file__).parents[1].joinpath("test_data/weights/weights_1.dat")
ds = Dataset.from_ase(
xyz_file,
energy_key="Energy",
forces_key="force",
weight=weight_file,
)
# all weights should be the same
configs = ds.get_configs()

weights = np.genfromtxt(weight_file, names=True)
config_weight = weights["Config"]
energy_weight = weights["Energy"]
forces_weight = weights["Forces"]
stress_weight = weights["Stress"]

assert len(configs) == 4
assert configs[0].weight.config_weight == config_weight
assert configs[0].weight.energy_weight == energy_weight
assert configs[0].weight.forces_weight == forces_weight
assert configs[0].weight.stress_weight == stress_weight
assert configs[3].weight.config_weight == config_weight
assert configs[3].weight.energy_weight == energy_weight
assert configs[3].weight.forces_weight == forces_weight
assert configs[3].weight.stress_weight == stress_weight


def test_incomplete_weights_from_file():
"""Load 3 weights from a file, this test should fail, with DatasetError, any other error is a failure of the test."""
xyz_file = Path(__file__).parents[1].joinpath("test_data/configs/Si_4.xyz")
weight_file = (
Path(__file__).parents[1].joinpath("test_data/weights/weights_4_incomplete.dat")
)
try:
ds = Dataset.from_ase(
xyz_file,
energy_key="Energy",
forces_key="force",
weight=weight_file,
)
except DatasetError:
assert True
except:
assert False, "Wrong expected Exception raised"
else:
assert False, "Expected Exception not raised"


def test_minimal_weights_from_file():
"""Load 2 weights from a file"""
xyz_file = Path(__file__).parents[1].joinpath("test_data/configs/Si_4.xyz")
weight_file = (
Path(__file__).parents[1].joinpath("test_data/weights/weights_4_partial.dat")
)

ds = Dataset.from_ase(
xyz_file,
energy_key="Energy",
forces_key="force",
weight=weight_file,
)
configs = ds.get_configs()
assert len(configs) == 4
weights = np.genfromtxt(weight_file, names=True)

assert len(weights.dtype.names) == 2
assert weights.dtype.names == ("Config", "Forces")

config_weights = weights["Config"]
forces_weights = weights["Forces"]

assert configs[0].weight.config_weight == config_weights[0]
assert configs[0].weight.energy_weight == 0.0
assert configs[0].weight.forces_weight == forces_weights[0]
assert configs[0].weight.stress_weight == 0.0

assert configs[3].weight.config_weight == config_weights[3]
assert configs[3].weight.energy_weight == 0.0
assert configs[3].weight.forces_weight == forces_weights[3]
assert configs[3].weight.stress_weight == 0.0
4 changes: 0 additions & 4 deletions tests/test_data/weights/Si_4_weights.dat

This file was deleted.

2 changes: 2 additions & 0 deletions tests/test_data/weights/weights_1.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Config Forces Energy Stress
2.0 20.0 200.0 0.0
5 changes: 5 additions & 0 deletions tests/test_data/weights/weights_4.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Config Forces Energy Stress
2.0 20.0 200.0 0.0
1.0 10.0 100.0 10.0
0.0 0.0 0.0 20.0
2.5 25.0 100.0 -10.0
3 changes: 3 additions & 0 deletions tests/test_data/weights/weights_4_incomplete.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Config Forces Energy Stress
2.0 20.0 200.0 1.0
1.0 10.0 100.0 10.0
Loading

0 comments on commit 289a2a9

Please sign in to comment.