diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index 6422b5a5..3a69d2e6 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -150,14 +150,23 @@ def initialize(self, atoms): self.neighbor_fn = neighbor_fn if self.neigbor_from_jax: - positions = jnp.asarray(atoms.positions, dtype=jnp.float64) - self.neighbors = self.neighbor_fn.allocate(positions) + if np.any(atoms.get_cell().lengths() > 1e-6): + positions = jnp.asarray(atoms.positions, dtype=jnp.float64) + box = atoms.cell.array.T + inv_box = jnp.linalg.inv(box) + positions = space.transform(inv_box, positions) # frac coords + self.neighbors = self.neighbor_fn.allocate(positions, box=box) + else: + self.neighbors = self.neighbor_fn.allocate(positions) + else: + idxs_i = neighbour_list("i", atoms, self.r_max) + self.padded_length = int(len(idxs_i) * self.padding_factor) def set_neighbours_and_offsets(self, atoms, box): idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max) if len(idxs_i) > self.padded_length: - print("neighbor list overflowed, reallocating.") + print("neighbor list overflowed, extending.") self.padded_length = int(len(idxs_i) * self.padding_factor) self.initialize(atoms) @@ -178,12 +187,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes): if self.step is None: self.initialize(atoms) - if self.neigbor_from_jax: - self.neighbors = self.neighbor_fn.allocate(positions) - else: - idxs_i = neighbour_list("i", atoms, self.r_max) - self.padded_length = int(len(idxs_i) * self.padding_factor) - elif "numbers" in system_changes: self.initialize(atoms) @@ -202,8 +205,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes): if self.neighbors.did_buffer_overflow: print("neighbor list overflowed, reallocating.") self.initialize(atoms) - self.neighbors = self.neighbor_fn.allocate(positions) - results, self.neighbors = self.step(positions, self.neighbors, box) else: @@ -267,10 +268,8 @@ def step_fn(positions, neighbor, box): @jax.jit def step_fn(positions, neighbor, box, offsets): results = model(positions, Z, neighbor, box, offsets) - if "stress" in results.keys(): results = process_stress(results, box) - return results return step_fn