Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Feb 26, 2024
1 parent ef616d3 commit 4a16d25
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 16 deletions.
14 changes: 8 additions & 6 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,14 @@ def dataset_from_dicts(
for key, val in labels["fixed"].items():
labels["fixed"][key] = tf.constant(val)

ds = tf.data.Dataset.from_tensor_slices((
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
))
ds = tf.data.Dataset.from_tensor_slices(
(
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
)
)
return ds


Expand Down
15 changes: 9 additions & 6 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from jax_md import partition, quantity, space
from matscipy.neighbours import neighbour_list
from tqdm import trange
from apax.data.initialization import RawDataset, initialize_dataset

from apax.data.initialization import RawDataset, initialize_dataset
from apax.model import ModelBuilder
from apax.train.checkpoints import check_for_ensemble, restore_parameters
from apax.utils import jax_md_reduced
Expand All @@ -31,7 +31,6 @@ def maybe_vmap(apply, params):

def build_energy_neighbor_fns(atoms, config, params, dr_threshold, neigbor_from_jax):
r_max = config.model.r_max
atomic_numbers = jnp.asarray(atoms.numbers)
box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
neigbor_from_jax = neighbor_calculable_with_jax(box, r_max)
box = box.T
Expand Down Expand Up @@ -93,9 +92,9 @@ def unpack_results(results, inputs):
unpacked_results = []
for i in range(n_structures):
single_results = jax.tree_map(lambda x: x[i], results)
for k,v in single_results.items():
for k, v in single_results.items():
if "forces" in k:
single_results[k] = v[:inputs["n_atoms"][i]]
single_results[k] = v[: inputs["n_atoms"][i]]
unpacked_results.append(single_results)
return unpacked_results

Expand Down Expand Up @@ -234,15 +233,19 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
def batch_eval(self, data, batch_size=64, silent=False):
if self.model is None:
self.initialize(data[0])
dataset = initialize_dataset(self.model_config, RawDataset(atoms_list=data), calc_stats=False)
dataset = initialize_dataset(
self.model_config, RawDataset(atoms_list=data), calc_stats=False
)
dataset.set_batch_size(batch_size)

evaluated_data = []
n_data = dataset.n_data
ds = dataset.batch()
batched_model = jax.jit(jax.vmap(self.model, in_axes=(0, 0, 0, 0, 0)))

pbar = trange(n_data, desc="Computing features", ncols=100, leave=True, disable=silent)
pbar = trange(
n_data, desc="Computing features", ncols=100, leave=True, disable=silent
)
for i, (inputs, _) in enumerate(ds):
positions_b, Z_b, neighbor_b, box_b, offsets_b = (
inputs["positions"],
Expand Down
10 changes: 6 additions & 4 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})
epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)

epoch_metrics.update({**epoch_loss})

Expand Down

0 comments on commit 4a16d25

Please sign in to comment.