Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Apr 5, 2024
1 parent de9ad94 commit 6930533
Showing 1 changed file with 16 additions and 35 deletions.
51 changes: 16 additions & 35 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import logging
import pathlib
import typing as t
from typing import Optional

import ase.io
import pandas as pd
import yaml
import zntrack.utils
from jax import config
from zntrack import dvc, zn

from apax.md import ASECalculator
from apax.md.function_transformations import available_transformations
Expand Down Expand Up @@ -36,36 +35,24 @@ class Apax(zntrack.Node):
"""

data: list = zntrack.deps()
config: str = dvc.params("apax.yaml")
config: str = zntrack.params_path()
validation_data = zntrack.deps()
model: Optional[t.Any] = zntrack.deps(None)
model: t.Optional[t.Any] = zntrack.deps(None)

model_directory: pathlib.Path = dvc.outs(zntrack.nwd / "apax_model")
model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

train_data_file: pathlib.Path = dvc.outs(zntrack.nwd / "train_atoms.extxyz")
validation_data_file: pathlib.Path = dvc.outs(zntrack.nwd / "val_atoms.extxyz")
train_data_file: pathlib.Path = zntrack.outs_path(zntrack.nwd / "train_atoms.extxyz")
validation_data_file: pathlib.Path = zntrack.outs_path(
zntrack.nwd / "val_atoms.extxyz"
)

jax_enable_x64: bool = zn.params(True)
# TODO: why is this an extra parameter?
jax_enable_x64: bool = zntrack.params(True)

# metrics_epoch = zntrack.plots_path(
# zntrack.nwd / "apax_model" / "log.csv",
# # template=STATIC_PATH / "y_log.json",
# # x="epoch",
# # x_label="epochs",
# # y="val_loss",
# # y_label="validation loss",
# )
# metrics = zn.metrics()
metrics = zntrack.metrics()

_parameter: dict = None

# def _post_init_(self):
# self.data = utils.helpers.get_deps_if_node(self.data, "atoms")
# self.validation_data = utils.helpers.get_deps_if_node(
# self.validation_data, "atoms"
# )
# self._handle_parameter_file()

def _post_load_(self) -> None:
self._handle_parameter_file()

Expand Down Expand Up @@ -95,15 +82,10 @@ def train_model(self):
"""Train the model using `apax.train.run`"""
apax_run(self._parameter)

# def move_metrics(self):
# """Move the metrics to the correct directories for DVC"""
# path = self.model_directory / self.metrics_epoch.name
# shutil.move(path, self.metrics_epoch)

# def get_metrics_from_plots(self):
# """In addition to the plots write a model metric"""
# metrics_df = pd.read_csv(self.metrics_epoch)
# self.metrics = metrics_df.iloc[-1].to_dict()
def get_metrics_from_plots(self):
"""In addition to the plots write a model metric"""
metrics_df = pd.read_csv(self.model_directory / "log.csv")
self.metrics = metrics_df.iloc[-1].to_dict()

def run(self):
"""Primary method to run which executes all steps of the model training"""
Expand All @@ -114,8 +96,7 @@ def run(self):
ase.io.write(self.validation_data_file, self.validation_data)

self.train_model()
# self.move_metrics()
# self.get_metrics_from_plots()
self.get_metrics_from_plots()

def get_calculator(self, **kwargs):
"""Get an apax ase calculator"""
Expand Down

0 comments on commit 6930533

Please sign in to comment.