From 87b37ec66cc8853c58698953524e630ceb19fe4b Mon Sep 17 00:00:00 2001 From: Cagri Kaymak Date: Sun, 11 Feb 2024 08:17:37 -0500 Subject: [PATCH] add torsion restraint support --- jaxreaxff/optimizer.py | 75 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/jaxreaxff/optimizer.py b/jaxreaxff/optimizer.py index 122cbc3..ced18ed 100644 --- a/jaxreaxff/optimizer.py +++ b/jaxreaxff/optimizer.py @@ -50,9 +50,19 @@ def calculate_bond_restraint_energy(positions, structure): (1.0 - jnp.exp(-f2s * (cur_dists - targets)**2))) return rest_pot +def angle_difference(angle1, angle2): + ''' + Calculate angle difference between 2 angles + while respecting periodicity + Ex. diff between 170 and -170 is 20 degree. + ''' + diff = jnp.mod(angle1 - angle2, 360) + diff = jnp.where(diff < 180, diff, 360 - diff) + return diff + def calculate_angle_restraint_energy(positions, structure): ''' - Calculate bond restraint potential + Calculate angle restraint potential Erestraint= Force1*{1.0-exp(Force2*(angle-target_angle^2} ''' @@ -77,16 +87,66 @@ def calculate_angle_restraint_energy(positions, structure): jnp.arccos, cur_angle).astype(disp12.dtype) # convert it to degree cur_angle = cur_angle * rdndgr - # to have periodicity, Ex. diff between 170 and -170 is 20 degree. - cur_angle = jnp.where(cur_angle < 0.0, cur_angle + 360.0, cur_angle) - target = jnp.where(target < 0.0, target+360.0, target) - diff = (cur_angle - target) * dgrrdn + # calculate the difference + diff = angle_difference(cur_angle, target) * dgrrdn rest_pot = jnp.sum(mask * f1s * (1.0 - jnp.exp(-f2s * (diff)**2))) return rest_pot +def calculate_torsion_angle(p1, p2, p3, p4): + """ + Calculate the torsion angle + [1-(2-3)-4] ---- (2-3 is the center) + Praxeolitic formula + Taken from: https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python + """ + + b0 = -1.0*(p2 - p1) + b1 = p3 - p2 + b2 = p4 - p3 + + # normalize b1 so that it does not influence magnitude of vector + # rejections that come next + b1 /= jnp.linalg.norm(b1 + 1e-10) + + # vector rejections + # v = projection of b0 onto plane perpendicular to b1 + # = b0 minus component that aligns with b1 + # w = projection of b2 onto plane perpendicular to b1 + # = b2 minus component that aligns with b1 + v = b0 - jnp.dot(b0, b1)*b1 + w = b2 - jnp.dot(b2, b1)*b1 + + # angle between v and w in a plane is the torsion angle + # v and w may not be normalized but that's fine since tan is y/x + x = jnp.dot(v, w) + y = jnp.dot(jnp.cross(b1, v), w) + return jnp.degrees(jnp.arctan2(y, x+1e-10)) + def calculate_torsion_restraint_energy(positions, structure): - pass + ''' + Calculate torsion restraint potential + Erestraint = Force1*{1.0-exp(Force2*(angle-target_angle^2} + ''' + + torsion_restraints = structure.torsion_restraints + + ind1s = torsion_restraints.ind1 + ind2s = torsion_restraints.ind2 + ind3s = torsion_restraints.ind3 + ind4s = torsion_restraints.ind4 + target = torsion_restraints.target + f1s = torsion_restraints.force1 + f2s = torsion_restraints.force2 + mask = ind1s != -1 + #TODO: The angle calculation expects both atoms to be in the center box, + # does not work when it crosses periodic boundary + cur_angle = jax.vmap(calculate_torsion_angle)(positions[ind1s], positions[ind2s], + positions[ind3s], positions[ind4s]) + # calculate the difference + diff = angle_difference(cur_angle, target) * dgrrdn + rest_pot = jnp.sum(mask * f1s * (1.0 - jnp.exp(-f2s * (diff)**2))) + return rest_pot def calculate_energy_and_charges(positions, @@ -132,8 +192,9 @@ def calculate_energy_and_charges_w_rest(positions, force_field) bond_rest_en = calculate_bond_restraint_energy(positions, structure) angle_rest_en = calculate_angle_restraint_energy(positions, structure) + torsion_rest_en = calculate_torsion_restraint_energy(positions, structure) - energy = energy + bond_rest_en + angle_rest_en + energy = energy + bond_rest_en + angle_rest_en + torsion_rest_en return energy, charges def calculate_loss(force_field,