diff --git a/apax/train/loss.py b/apax/train/loss.py index 56096455..cd28786e 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -52,6 +52,13 @@ def force_angle_exponential_weight( return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor +def stress_tril(label, prediction, divisor=1.0): + idxs = jnp.tril_indices(3) + label_tril = label[:, idxs[0], idxs[1]] + prediction_tril = prediction[:, idxs[0], idxs[1]] + return (label_tril - prediction_tril) ** 2 / divisor + + loss_functions = { "molecules": weighted_squared_error, "structures": weighted_squared_error, @@ -59,6 +66,7 @@ def force_angle_exponential_weight( "cosine_sim": force_angle_loss, "cosine_sim_div_magnitude": force_angle_div_force_label, "cosine_sim_exp_magnitude": force_angle_exponential_weight, + "tril": stress_tril, } @@ -101,6 +109,7 @@ def determine_divisor(self, n_atoms: jnp.array) -> jnp.array: n_atoms, "batch -> batch 1 1" ), "stress_structures": einops.repeat(n_atoms**2, "batch -> batch 1 1"), + "stress_tril": einops.repeat(n_atoms**2, "batch -> batch 1 1"), "stress_vibrations": einops.repeat(n_atoms, "batch -> batch 1 1"), } divisor = divisor_dict.get(divisor_id, jnp.array(1.0))