Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.3.0 changes #223

Merged
merged 24 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2b49283
added model loading edge case for `experiment=""`
M-R-Schaefer Dec 20, 2023
2f44651
Merge pull request #216 from apax-hub/ips_compat
M-R-Schaefer Dec 21, 2023
ecd2759
switched model saving to absolute path
M-R-Schaefer Dec 27, 2023
bbbbd76
poetry update
M-R-Schaefer Dec 28, 2023
8e0cd03
fixed model loading test
M-R-Schaefer Dec 28, 2023
fc9a75c
fixed jaxmd NL not being reallocated when going from matscipy to jaxm…
M-R-Schaefer Jan 2, 2024
b19adfd
linting
M-R-Schaefer Jan 2, 2024
1c3e2d5
added fp64 version of clu `Average`
M-R-Schaefer Jan 2, 2024
9f21feb
added sketch of multi step jit
M-R-Schaefer Jan 2, 2024
93cb39d
applied fix to jaxMD
M-R-Schaefer Jan 8, 2024
9d712f4
added stress triL loss function
M-R-Schaefer Jan 8, 2024
210c06d
linting
M-R-Schaefer Jan 8, 2024
0e9a5cf
added n_jitted_steps to config
M-R-Schaefer Jan 8, 2024
f3d8a60
linting
M-R-Schaefer Jan 9, 2024
4ecd6f9
updated linting action versions
M-R-Schaefer Jan 9, 2024
76bdd87
linting
M-R-Schaefer Jan 9, 2024
e828e52
removed debug print statement
Jan 10, 2024
55b96b6
Merge pull request #218 from apax-hub/orbax_fix
M-R-Schaefer Jan 11, 2024
008c3a8
Merge branch 'dev' into epoch_jit
PythonFZ Jan 11, 2024
add18eb
Merge branch 'dev' into stress_til_loss
M-R-Schaefer Jan 15, 2024
2e6ec7a
Merge pull request #221 from apax-hub/stress_til_loss
M-R-Schaefer Jan 16, 2024
0266e44
Merge branch 'dev' into epoch_jit
M-R-Schaefer Jan 16, 2024
dae2002
Merge pull request #220 from apax-hub/epoch_jit
M-R-Schaefer Jan 16, 2024
2b51835
Merge branch 'main' into dev
M-R-Schaefer Jan 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
self.model_config, self.params = restore_parameters(model_dir)
self.n_models = check_for_ensemble(self.params)
self.padding_factor = padding_factor
self.padded_length = 0

if self.model_config.model.calc_stress:
self.implemented_properties.append("stress")
Expand Down Expand Up @@ -148,6 +149,10 @@ def initialize(self, atoms):
self.step = get_step_fn(model, atoms, self.neigbor_from_jax)
self.neighbor_fn = neighbor_fn

if self.neigbor_from_jax:
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
self.neighbors = self.neighbor_fn.allocate(positions)

def set_neighbours_and_offsets(self, atoms, box):
idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max)

Expand Down
2 changes: 1 addition & 1 deletion apax/md/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def body_fn(i, state):
)
ckpt = {"state": state, "step": step}
checkpoints.save_checkpoint(
ckpt_dir=ckpt_dir,
ckpt_dir=ckpt_dir.resolve(),
target=ckpt,
step=step,
overwrite=True,
Expand Down
4 changes: 2 additions & 2 deletions apax/train/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class CheckpointManager:
def __init__(self) -> None:
self.async_manager = checkpoints.AsyncManager()

def save_checkpoint(self, ckpt, epoch: int, path: str) -> None:
def save_checkpoint(self, ckpt, epoch: int, path: Path) -> None:
checkpoints.save_checkpoint(
ckpt_dir=path,
ckpt_dir=path.resolve(),
target=ckpt,
step=epoch,
overwrite=True,
Expand Down
Loading