diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index aca2e52d..0bfc4020 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -17,7 +17,7 @@ from apax.utils import jax_md_reduced -def maybe_vmap(apply, params, Z): +def maybe_vmap(apply, params): n_models = check_for_ensemble(params) if n_models > 1: @@ -54,7 +54,6 @@ def build_energy_neighbor_fns(atoms, config, params, dr_threshold, neigbor_from_ format=partition.Sparse, ) - Z = jnp.asarray(atomic_numbers) n_species = 119 # int(np.max(Z) + 1) builder = ModelBuilder(config.model.get_dict(), n_species=n_species) @@ -62,8 +61,7 @@ def build_energy_neighbor_fns(atoms, config, params, dr_threshold, neigbor_from_ apply_mask=True, init_box=np.array(box), inference_disp_fn=displacement_fn ) - energy_fn = maybe_vmap(model.apply, params, Z) - + energy_fn = maybe_vmap(model.apply, params) return energy_fn, neighbor_fn @@ -91,17 +89,17 @@ def ensemble(positions, Z, idx, box, offsets): def unpack_results(results, inputs): + n_structures = len(results["energy"]) unpacked_results = [] - - unpacked_results = jax.tree_transpose( - outer_treedef = jax.tree_structure(results), - inner_treedef = jax.tree_structure([0 for r in results["energy"]]), - pytree_to_transpose = results - ) + for i in range(n_structures): + single_results = jax.tree_map(lambda x: x[i], results) + for k,v in single_results.items(): + if "forces" in k: + single_results[k] = v[:inputs["n_atoms"][i]] + unpacked_results.append(single_results) return unpacked_results - class ASECalculator(Calculator): """ ASE Calculator for apax models. @@ -239,19 +237,25 @@ def batch_eval(self, data, batch_size=64, silent=False): dataset = initialize_dataset(self.model_config, RawDataset(atoms_list=data), calc_stats=False) dataset.set_batch_size(batch_size) - # features = [] evaluated_data = [] n_data = dataset.n_data ds = dataset.batch() - batched_model = jax.jit(jax.vmap(self.model, in_axes=(None, 0, 0, 0, 0, 0))) + batched_model = jax.jit(jax.vmap(self.model, in_axes=(0, 0, 0, 0, 0))) pbar = trange(n_data, desc="Computing features", ncols=100, leave=True, disable=silent) for i, (inputs, _) in enumerate(ds): - results = batched_model(inputs) + positions_b, Z_b, neighbor_b, box_b, offsets_b = ( + inputs["positions"], + inputs["numbers"], + inputs["idx"], + inputs["box"], + inputs["offsets"], + ) + results = batched_model(positions_b, Z_b, neighbor_b, box_b, offsets_b) unpadded_results = unpack_results(results, inputs) for j in range(batch_size): atoms = data[i].copy() - atoms.calc = SinglePointCalculator(atoms=atoms, results=unpadded_results[j]) + atoms.calc = SinglePointCalculator(atoms=atoms, **unpadded_results[j]) evaluated_data.append(atoms) pbar.update(batch_size) pbar.close()