Skip to content

Commit

Permalink
add torsion error in the loss function
Browse files Browse the repository at this point in the history
  • Loading branch information
cagrikymk committed Feb 17, 2024
1 parent 2813449 commit 28280c2
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions jaxreaxff/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,19 +324,29 @@ def calculate_loss(force_field,
cur_angle = safe_mask((cur_angle < 1) & (cur_angle > -1),
jnp.arccos, cur_angle).astype(disp12.dtype)
# convert it to degree
angles = cur_angle * rdndgr

# to have periodicity, Ex. diff between 170 and -170 is 20 degree.
angles = jnp.where(angles < 0.0, angles+360.0, angles)
targets = angle_items.target
targets = jnp.where(targets < 0.0, targets+360.0, targets)

angle_errors = ((targets - angles) /
angle_items.weight) ** 2
cur_angle = cur_angle * rdndgr
targets = angle_items.target
diff = angle_difference(cur_angle, targets)
angle_errors = (diff / angle_items.weight) ** 2
angle_error = jnp.sum(angle_errors)
total_error += angle_error
if return_indiv_error:
all_indiv_errors['ANGLE'] = [angles, targets, angle_errors]
all_indiv_errors['ANGLE'] = [cur_angle, targets, angle_errors]

if training_data.torsion_items != None:
torsion_items = training_data.torsion_items
pos1 = all_positions[torsion_items.sys_ind,torsion_items.a1_ind]
pos2 = all_positions[torsion_items.sys_ind,torsion_items.a2_ind]
pos3 = all_positions[torsion_items.sys_ind,torsion_items.a3_ind]
pos4 = all_positions[torsion_items.sys_ind,torsion_items.a4_ind]
cur_angle = jax.vmap(calculate_torsion_angle)(pos1, pos2, pos3, pos4)
targets = torsion_items.target
diff = angle_difference(cur_angle, targets)
torsion_errors = (diff / torsion_items.weight) ** 2
torsion_error = jnp.sum(torsion_errors)
total_error += torsion_error
if return_indiv_error:
all_indiv_errors['TORSION'] = [cur_angle, targets, torsion_errors]
if return_indiv_error:
return total_error, all_indiv_errors
return total_error
Expand Down

0 comments on commit 28280c2

Please sign in to comment.