Skip to content

Commit

Permalink
completed batch_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Feb 26, 2024
1 parent a57d06b commit 2830b4f
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from apax.utils import jax_md_reduced


def maybe_vmap(apply, params, Z):
def maybe_vmap(apply, params):
n_models = check_for_ensemble(params)

if n_models > 1:
Expand Down Expand Up @@ -54,16 +54,14 @@ def build_energy_neighbor_fns(atoms, config, params, dr_threshold, neigbor_from_
format=partition.Sparse,
)

Z = jnp.asarray(atomic_numbers)
n_species = 119 # int(np.max(Z) + 1)
builder = ModelBuilder(config.model.get_dict(), n_species=n_species)

model = builder.build_energy_derivative_model(
apply_mask=True, init_box=np.array(box), inference_disp_fn=displacement_fn
)

energy_fn = maybe_vmap(model.apply, params, Z)

energy_fn = maybe_vmap(model.apply, params)
return energy_fn, neighbor_fn


Expand Down Expand Up @@ -91,17 +89,17 @@ def ensemble(positions, Z, idx, box, offsets):


def unpack_results(results, inputs):
n_structures = len(results["energy"])
unpacked_results = []

unpacked_results = jax.tree_transpose(
outer_treedef = jax.tree_structure(results),
inner_treedef = jax.tree_structure([0 for r in results["energy"]]),
pytree_to_transpose = results
)
for i in range(n_structures):
single_results = jax.tree_map(lambda x: x[i], results)
for k,v in single_results.items():
if "forces" in k:
single_results[k] = v[:inputs["n_atoms"][i]]
unpacked_results.append(single_results)
return unpacked_results



class ASECalculator(Calculator):
"""
ASE Calculator for apax models.
Expand Down Expand Up @@ -239,19 +237,25 @@ def batch_eval(self, data, batch_size=64, silent=False):
dataset = initialize_dataset(self.model_config, RawDataset(atoms_list=data), calc_stats=False)
dataset.set_batch_size(batch_size)

# features = []
evaluated_data = []
n_data = dataset.n_data
ds = dataset.batch()
batched_model = jax.jit(jax.vmap(self.model, in_axes=(None, 0, 0, 0, 0, 0)))
batched_model = jax.jit(jax.vmap(self.model, in_axes=(0, 0, 0, 0, 0)))

pbar = trange(n_data, desc="Computing features", ncols=100, leave=True, disable=silent)
for i, (inputs, _) in enumerate(ds):
results = batched_model(inputs)
positions_b, Z_b, neighbor_b, box_b, offsets_b = (
inputs["positions"],
inputs["numbers"],
inputs["idx"],
inputs["box"],
inputs["offsets"],
)
results = batched_model(positions_b, Z_b, neighbor_b, box_b, offsets_b)
unpadded_results = unpack_results(results, inputs)
for j in range(batch_size):
atoms = data[i].copy()
atoms.calc = SinglePointCalculator(atoms=atoms, results=unpadded_results[j])
atoms.calc = SinglePointCalculator(atoms=atoms, **unpadded_results[j])
evaluated_data.append(atoms)
pbar.update(batch_size)
pbar.close()
Expand Down

0 comments on commit 2830b4f

Please sign in to comment.