Kliff DNN torch trainer #185

Merged: 4 commits, Jul 6, 2024

Changes from all commits:
3 changes: 1 addition & 2 deletions kliff/trainer/__init__.py
@@ -1,5 +1,4 @@
 from .base_trainer import Trainer
 from .kim_trainer import KIMTrainer
 from .lightning_trainer import GNNLightningTrainer
-
-# from .torch_trainer import DNNTrainer
+from .torch_trainer import DNNTrainer
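
Uncommenting the export makes the torch trainer importable from the package root. A minimal usage sketch, assuming DNNTrainer keeps the base Trainer signature (training_manifest: dict, model=None) and the usual train() entry point; the manifest keys below are illustrative, not the authoritative schema:

from kliff.trainer import DNNTrainer

# Illustrative manifest; the real schema is defined by the trainer
# manifests in this PR and is not reproduced here.
manifest = {
    "workspace": {"name": "dnn_run"},
    "dataset": {"type": "ase", "path": "Si_training_set.xyz"},
    "training": {
        "optimizer": {"name": "Adam", "learning_rate": 1.0e-3},
        "epochs": 100,
        "batch_size": 2,
    },
}

trainer = DNNTrainer(manifest)
trainer.train()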
45 changes: 38 additions & 7 deletions kliff/trainer/base_trainer.py
@@ -149,7 +149,6 @@ def __init__(self, training_manifest: dict, model=None):
             "learning_rate": None,
             "kwargs": None,
             "epochs": 10000,
-            "stop_condition": None,
             "num_workers": None,
             "batch_size": 1,
         }
@@ -281,9 +280,6 @@ def parse_manifest(self, manifest: dict):
 
         self.optimizer_manifest |= self.training_manifest.get("optimizer")
         self.optimizer_manifest["epochs"] = self.training_manifest.get("epochs", 10000)
-        self.optimizer_manifest["stop_condition"] = self.training_manifest.get(
-            "stop_condition", None
-        )
         self.optimizer_manifest["num_workers"] = self.training_manifest.get(
             "num_workers", None
         )
@@ -357,8 +353,9 @@ def initialize(self):
         # Step 5 - Set up the test and train datasets, based on the provided indices
         self.setup_dataset_split()
         logger.info(f"Train and validation datasets set up.")
-        # Step 6 - Set up the model
-        self.setup_model()
+        # Step 6 - Set up the model, if not provided
+        if not self.model:
+            self.setup_model()
         logger.info(f"Model loaded.")
         # Step 6.5 - Setup parameter transform
         self.setup_parameter_transforms()
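
With this guard, a caller-supplied model takes precedence and setup_model() is only invoked as a fallback. A minimal sketch of the pattern, assuming the torch-based trainer accepts any torch.nn.Module (layer sizes are illustrative):

import torch

# Because initialize() now checks self.model first, this user-built
# network is used as-is and setup_model() is skipped.
net = torch.nn.Sequential(
    torch.nn.Linear(51, 32),  # 51 = descriptor width, illustrative
    torch.nn.Tanh(),
    torch.nn.Linear(32, 1),   # per-atom energy contribution
)

trainer = DNNTrainer(manifest, model=net)  # manifest as sketched above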
@@ -538,7 +535,6 @@ def setup_dataset_transforms(self):
            ConfigurationClass = ConfigurationClass(
                **kwargs, copy_to_config=False
            )
-
            self.configuration_transform = ConfigurationClass
 
    def setup_model(self):
@@ -720,6 +716,41 @@ def _generate_kim_cmake(model_name: str, driver_name: str, file_list: List) -> str:
         """
         return cmake
 
+    def write_training_env_edn(self, path: str):
+        """
+        Generate the training_env.edn file for the KIM API. This file is used to
+        accurately reconstruct the training environment. It is saved in the
+        current run directory and records the hash of the training manifest
+        along with the list of all Python dependencies from pip freeze.
+        """
+        env_file = f"{path}/training_env.edn"
+        manifest_hash = self.get_trainer_hash()
+        with open(env_file, "w") as f:
+            try:
+                from pip._internal.operations.freeze import freeze
+
+                from kliff import __version__
+            except ImportError:
+                logger.warning(
+                    "Could not import kliff version or pip freeze. Skipping."
+                )
+                return
+            python_env = []
+            for module in list(freeze()):
+                if "@" in module:
+                    # pip freeze emits "name @ url" for VCS/direct installs;
+                    # keep only the package name, without the trailing space
+                    module = module.split("@")[0].strip()
+                python_env.append(module)
+
+            f.write("{\n")
+            f.write(f'"kliff-version" "{__version__}"\n')
+            f.write(f'"trainer-used" "{type(self).__name__}"\n')
+            f.write(f'"manifest-hash" "{manifest_hash}"\n')
+            f.write('"python-dependencies" [\n')
+            for module in python_env:
+                f.write(f'    "{module}"\n')
+            f.write("]\n")
+            f.write("}\n")
+
 
 # Parallel processing for dataset loading #############################################
 def _parallel_read(
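
For reference, the file emitted by write_training_env_edn is a small EDN map that mirrors the f.write calls above; the version, hash, and package pins here are illustrative:

{
"kliff-version" "0.4.4"
"trainer-used" "KIMTrainer"
"manifest-hash" "9f86d081884c7d65"
"python-dependencies" [
    "numpy==1.26.4"
    "torch==2.2.0"
]
}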
4 changes: 4 additions & 0 deletions kliff/trainer/kim_trainer.py
@@ -206,9 +206,13 @@ def save_kim_model(self):
             / self.export_manifest["model_name"]
         )
         self.model.write_kim_model(path)
+        self.write_training_env_edn(path)
         if self.export_manifest["generate_tarball"]:
             tarfile_path = path.with_suffix(".tar.gz")
             with tarfile.open(tarfile_path, "w:gz") as tar:
                 tar.add(path, arcname=path.name)
+            logger.info(f"Model tarball saved: {tarfile_path}")
+        logger.info(f"KIM model saved at {path}")
+
 
 # TODO: Support for lst_sq in optimizer
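
Taken together, export is driven by a small manifest section. A hedged sketch of it: "model_name" and "generate_tarball" appear in the diff above, while "model_path" is an assumed companion key for the base directory:

# Illustrative export section consumed by save_kim_model().
export_manifest = {
    "model_path": "./",           # assumed base-directory key
    "model_name": "SW_trained__MO_000000000000_000",  # illustrative KIM ID
    "generate_tarball": True,     # also produce <model_name>.tar.gz
}
# save_kim_model() writes the KIM model files, training_env.edn, and,
# when requested, a tarball ready for KIM API installation.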
6 changes: 6 additions & 0 deletions kliff/trainer/lightning_trainer.py
@@ -274,6 +274,7 @@ def __init__(self, manifest, model=None):
         self.callbacks = self._get_callbacks()
 
         # setup lightning trainer
+        self.setup_model()  # call setup_model explicitly, as it converts the torch model to a lightning module
         self.pl_trainer = self._get_pl_trainer()
 
     def setup_model(self):
@@ -293,6 +294,7 @@ def setup_model(self):
         ema = True if self.optimizer_manifest.get("ema", False) else False
         if ema:
             ema_decay = self.optimizer_manifest.get("ema_decay", 0.99)
+            logger.info(f"Using Exponential Moving Average with decay rate {ema_decay}")
         else:
             ema_decay = None
 
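
The EMA toggle is read straight from the optimizer manifest; a minimal sketch of the keys consumed above, with the default decay the code falls back to:

# Illustrative optimizer section; "ema" and "ema_decay" are the keys
# read in setup_model() above.
optimizer_manifest = {
    "name": "Adam",
    "learning_rate": 1.0e-3,
    "ema": True,        # enable an exponential moving average of weights
    "ema_decay": 0.99,  # default used when the key is absent
}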
@@ -312,6 +314,7 @@ def setup_model(self):
             lr_scheduler=scheduler.get("name", None),
             lr_scheduler_args=scheduler.get("args", None),
         )
+        logger.info("Lightning Model setup complete.")
 
     def train(self):
         """
@@ -513,6 +516,9 @@ def save_kim_model(self, path: str = "kim-model"):
         with open(f"{path}/CMakeLists.txt", "w") as f:
             f.write(cmakefile)
 
+        # write training environment
+        self.write_training_env_edn(path)
+
         if self.export_manifest["generate_tarball"]:
             tarball_name = f"{path}.tar.gz"
             with tarfile.open(tarball_name, "w:gz") as tar:
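
Export from the lightning trainer mirrors the KIM trainer; a hedged usage sketch, assuming a trained GNNLightningTrainer instance:

# Writes CMakeLists.txt, the trained model files, and training_env.edn
# under "kim-model/", plus kim-model.tar.gz if requested in the manifest.
trainer.save_kim_model(path="kim-model")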