Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed jax nl always allocating in cartesian coords #229

Merged
merged 6 commits into from
Feb 2, 2024
Merged
25 changes: 12 additions & 13 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,23 @@ def initialize(self, atoms):
self.neighbor_fn = neighbor_fn

if self.neigbor_from_jax:
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
self.neighbors = self.neighbor_fn.allocate(positions)
if np.any(atoms.get_cell().lengths() > 1e-6):
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
box = atoms.cell.array.T
inv_box = jnp.linalg.inv(box)
positions = space.transform(inv_box, positions) # frac coords
self.neighbors = self.neighbor_fn.allocate(positions, box=box)
else:
self.neighbors = self.neighbor_fn.allocate(positions)
else:
idxs_i = neighbour_list("i", atoms, self.r_max)
self.padded_length = int(len(idxs_i) * self.padding_factor)

def set_neighbours_and_offsets(self, atoms, box):
idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max)

if len(idxs_i) > self.padded_length:
print("neighbor list overflowed, reallocating.")
print("neighbor list overflowed, extending.")
self.padded_length = int(len(idxs_i) * self.padding_factor)
self.initialize(atoms)

Expand All @@ -178,12 +187,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
if self.step is None:
self.initialize(atoms)

if self.neigbor_from_jax:
self.neighbors = self.neighbor_fn.allocate(positions)
else:
idxs_i = neighbour_list("i", atoms, self.r_max)
self.padded_length = int(len(idxs_i) * self.padding_factor)

elif "numbers" in system_changes:
self.initialize(atoms)

Expand All @@ -202,8 +205,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
if self.neighbors.did_buffer_overflow:
print("neighbor list overflowed, reallocating.")
self.initialize(atoms)
self.neighbors = self.neighbor_fn.allocate(positions)

results, self.neighbors = self.step(positions, self.neighbors, box)

else:
Expand Down Expand Up @@ -267,10 +268,8 @@ def step_fn(positions, neighbor, box):
@jax.jit
def step_fn(positions, neighbor, box, offsets):
results = model(positions, Z, neighbor, box, offsets)

if "stress" in results.keys():
results = process_stress(results, box)

return results

return step_fn
Loading