Skip to content

Commit

Permalink
add torsion restraint support
Browse files Browse the repository at this point in the history
  • Loading branch information
cagrikymk committed Feb 11, 2024
1 parent 68c6e46 commit 87b37ec
Showing 1 changed file with 68 additions and 7 deletions.
75 changes: 68 additions & 7 deletions jaxreaxff/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,19 @@ def calculate_bond_restraint_energy(positions, structure):
(1.0 - jnp.exp(-f2s * (cur_dists - targets)**2)))
return rest_pot

def angle_difference(angle1, angle2):
'''
Calculate angle difference between 2 angles
while respecting periodicity
Ex. diff between 170 and -170 is 20 degree.
'''
diff = jnp.mod(angle1 - angle2, 360)
diff = jnp.where(diff < 180, diff, 360 - diff)
return diff

def calculate_angle_restraint_energy(positions, structure):
'''
Calculate bond restraint potential
Calculate angle restraint potential
Erestraint= Force1*{1.0-exp(Force2*(angle-target_angle^2}
'''

Expand All @@ -77,16 +87,66 @@ def calculate_angle_restraint_energy(positions, structure):
jnp.arccos, cur_angle).astype(disp12.dtype)
# convert it to degree
cur_angle = cur_angle * rdndgr
# to have periodicity, Ex. diff between 170 and -170 is 20 degree.
cur_angle = jnp.where(cur_angle < 0.0, cur_angle + 360.0, cur_angle)
target = jnp.where(target < 0.0, target+360.0, target)
diff = (cur_angle - target) * dgrrdn
# calculate the difference
diff = angle_difference(cur_angle, target) * dgrrdn
rest_pot = jnp.sum(mask * f1s * (1.0 - jnp.exp(-f2s * (diff)**2)))

return rest_pot

def calculate_torsion_angle(p1, p2, p3, p4):
"""
Calculate the torsion angle
[1-(2-3)-4] ---- (2-3 is the center)
Praxeolitic formula
Taken from: https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python
"""

b0 = -1.0*(p2 - p1)
b1 = p3 - p2
b2 = p4 - p3

# normalize b1 so that it does not influence magnitude of vector
# rejections that come next
b1 /= jnp.linalg.norm(b1 + 1e-10)

# vector rejections
# v = projection of b0 onto plane perpendicular to b1
# = b0 minus component that aligns with b1
# w = projection of b2 onto plane perpendicular to b1
# = b2 minus component that aligns with b1
v = b0 - jnp.dot(b0, b1)*b1
w = b2 - jnp.dot(b2, b1)*b1

# angle between v and w in a plane is the torsion angle
# v and w may not be normalized but that's fine since tan is y/x
x = jnp.dot(v, w)
y = jnp.dot(jnp.cross(b1, v), w)
return jnp.degrees(jnp.arctan2(y, x+1e-10))

def calculate_torsion_restraint_energy(positions, structure):
pass
'''
Calculate torsion restraint potential
Erestraint = Force1*{1.0-exp(Force2*(angle-target_angle^2}
'''

torsion_restraints = structure.torsion_restraints

ind1s = torsion_restraints.ind1
ind2s = torsion_restraints.ind2
ind3s = torsion_restraints.ind3
ind4s = torsion_restraints.ind4
target = torsion_restraints.target
f1s = torsion_restraints.force1
f2s = torsion_restraints.force2
mask = ind1s != -1
#TODO: The angle calculation expects both atoms to be in the center box,
# does not work when it crosses periodic boundary
cur_angle = jax.vmap(calculate_torsion_angle)(positions[ind1s], positions[ind2s],
positions[ind3s], positions[ind4s])
# calculate the difference
diff = angle_difference(cur_angle, target) * dgrrdn
rest_pot = jnp.sum(mask * f1s * (1.0 - jnp.exp(-f2s * (diff)**2)))
return rest_pot


def calculate_energy_and_charges(positions,
Expand Down Expand Up @@ -132,8 +192,9 @@ def calculate_energy_and_charges_w_rest(positions,
force_field)
bond_rest_en = calculate_bond_restraint_energy(positions, structure)
angle_rest_en = calculate_angle_restraint_energy(positions, structure)
torsion_rest_en = calculate_torsion_restraint_energy(positions, structure)

energy = energy + bond_rest_en + angle_rest_en
energy = energy + bond_rest_en + angle_rest_en + torsion_rest_en
return energy, charges

def calculate_loss(force_field,
Expand Down

0 comments on commit 87b37ec

Please sign in to comment.