Skip to content

Commit

Permalink
Merge branch 'nl_allocate_fix' of https://github.com/apax-hub/apax in…
Browse files Browse the repository at this point in the history
…to nl_allocate_fix
  • Loading branch information
M-R-Schaefer committed Feb 1, 2024
2 parents 9aefa7d + 63066b5 commit e097ce4
Show file tree
Hide file tree
Showing 28 changed files with 137 additions and 252 deletions.
48 changes: 0 additions & 48 deletions .github/workflows/linting.yaml

This file was deleted.

12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,25 @@ fail_fast: true

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.1.1
hooks:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://gitlab.com/pycqa/flake8
rev: 5.0.1
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies: [ flake8-isort ]
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ build:

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
configuration: docs/source/conf.py
2 changes: 1 addition & 1 deletion apax/cli/templates/train_config_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>

n_train: 1000
n_valid: 100

Expand Down
14 changes: 6 additions & 8 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,12 @@ def dataset_from_dicts(
for key, val in labels["fixed"].items():
labels["fixed"][key] = tf.constant(val)

ds = tf.data.Dataset.from_tensor_slices(
(
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
)
)
ds = tf.data.Dataset.from_tensor_slices((
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
))
return ds


Expand Down
3 changes: 2 additions & 1 deletion apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def initialize(self, atoms):
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
box = atoms.cell.array.T
inv_box = jnp.linalg.inv(box)
positions = space.transform(inv_box, positions) # frac coords
positions = space.transform(inv_box, positions) # frac coords
self.neighbors = self.neighbor_fn.allocate(positions, box=box)
else:
self.neighbors = self.neighbor_fn.allocate(positions)
Expand Down Expand Up @@ -264,6 +264,7 @@ def step_fn(positions, neighbor, box):
return results, neighbor

else:

@jax.jit
def step_fn(positions, neighbor, box, offsets):
results = model(positions, Z, neighbor, box, offsets)
Expand Down
14 changes: 4 additions & 10 deletions apax/train/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ def create_single_train_state(params):
return state

if n_models > 1:
train_state_fn = jax.vmap(
create_single_train_state,
axis_name="ensemble"
)
train_state_fn = jax.vmap(create_single_train_state, axis_name="ensemble")
else:
train_state_fn = create_single_train_state

Expand Down Expand Up @@ -129,9 +126,7 @@ def load_params(model_version_path: Path, best=True) -> FrozenDict:
try:
# keep try except block for zntrack load from rev
raw_restored = checkpoints.restore_checkpoint(
model_version_path,
target=None,
step=None
model_version_path, target=None, step=None
)
except FileNotFoundError:
print(f"No checkpoint found at {model_version_path}")
Expand All @@ -143,8 +138,7 @@ def load_params(model_version_path: Path, best=True) -> FrozenDict:


def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
"""Load the config and parameters of a single model
"""
"""Load the config and parameters of a single model"""
model_dir = Path(model_dir)
model_config = parse_config(model_dir / "config.yaml")

Expand Down Expand Up @@ -200,6 +194,6 @@ def canonicalize_energy_grad_model_parameters(params):

first_level = param_dict["params"]
if "energy_model" not in first_level.keys():
params = {"params": {"energy_model" : first_level}}
params = {"params": {"energy_model": first_level}}
params = freeze(params)
return params
10 changes: 4 additions & 6 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,10 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)
epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})

epoch_metrics.update({**epoch_loss})

Expand Down
2 changes: 1 addition & 1 deletion apax/utils/jax_md_reduced/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ We would like to thank the developers of `jax_md` for the work on this great pac
volume = {33},
year = {2020}
}
```
```
2 changes: 1 addition & 1 deletion docs/source/_tutorials/md_with_ase.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ An ASE calculator of a trained model can be instantiated as follows

CODE

Please refer to the ASE documentation LINK to see how to use ASE calculators.
Please refer to the ASE documentation LINK to see how to use ASE calculators.
5 changes: 1 addition & 4 deletions docs/source/_tutorials/molecular_dynamics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ Congratulations, you have calculated the first observable from a trajectory gene

## Custom Simulation Loops

More complex simulation loops are relatively easy to build yourself in JaxMD (see their colab notebooks for examples).
More complex simulation loops are relatively easy to build yourself in JaxMD (see their colab notebooks for examples).
Trained apax models can of course be used as `energy_fn` in such custom simulations.
If you have a suggestion for adding some MD feature or thermostat to the core of `apax`, feel free to open up an issue on Github LINK.



2 changes: 1 addition & 1 deletion docs/source/_tutorials/training_a_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ We provide a separate command for test set evaluation:

TODO pretty print results to the terminal

Congratulations, you have successfully trained and evaluated your fitrst apax model!
Congratulations, you have successfully trained and evaluated your fitrst apax model!
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@

html_theme = "furo"

html_theme_options = {}
html_theme_options = {}
2 changes: 1 addition & 1 deletion docs/source/getting_started/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Getting Started
.. toctree::
:maxdepth: 2

install
install
2 changes: 1 addition & 1 deletion docs/source/getting_started/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ If you want to enable GPU support, please overwrite the jaxlib version:
See the `Jax installation instructions <https://github.com/google/jax#installation>`_ for more details.


.. _Poetry: https://python-poetry.org/
.. _Poetry: https://python-poetry.org/
2 changes: 1 addition & 1 deletion docs/source/modules/md.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ Molecular Dynamics
:members:

.. automodule:: apax.md.nvt
:members:
:members:
2 changes: 1 addition & 1 deletion docs/source/modules/optimizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ Optimizers
==========

.. automodule:: apax.optimizer.get_optimizer
:members:
:members:
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "apax"
version = "0.2.1"
version = "0.3.0"
description = "Atomistic Learned Potential Package in JAX"
authors = ["Moritz René Schäfer <[email protected]>", "Nico Segreto <[email protected]>"]
keywords=["machine-learning", "interatomic potentials", "molecular-dynamics"]
Expand Down Expand Up @@ -76,4 +76,3 @@ directory = "coverage_html_report"

[tool.coverage.report]
show_missing = true

12 changes: 5 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,11 @@ def modify_xyz_file(file_path, target_string, replacement_string):

@pytest.fixture()
def get_sample_input():
positions = np.array(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
positions = np.array([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
])
atomic_numbers = np.array([1, 1, 8])
box = np.diag(np.zeros(3))
offsets = np.full([3, 3], 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/bal/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ metrics:

loss:
- name: energy
- name: forces
- name: forces
2 changes: 1 addition & 1 deletion tests/integration_tests/md/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ metrics:

loss:
- name: energy
- name: forces
- name: forces
2 changes: 1 addition & 1 deletion tests/integration_tests/md/md_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ duration: 0.25
n_inner: 1
sampling_rate: 1
checkpoint_interval: 2
restart: True
restart: True
12 changes: 5 additions & 7 deletions tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,11 @@ def test_ase_calc(get_tmp_path):
model_config.dump_config(model_config.data.model_version_path)

cell_size = 10.0
positions = np.array(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
positions = np.array([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
])
atomic_numbers = np.array([1, 1, 8])
box = np.diag([cell_size] * 3)
offsets = jnp.full([3, 3], 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/transfer_learning/config_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ optimizer:
emb_lr: 0.001
nn_lr: 0.001
scale_lr: 0.0001
shift_lr: 0.001
shift_lr: 0.001
2 changes: 1 addition & 1 deletion tests/integration_tests/transfer_learning/config_ft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ optimizer:
shift_lr: 0.001

checkpoints:
base_model_checkpoint: null
base_model_checkpoint: null
Loading

0 comments on commit e097ce4

Please sign in to comment.