diff --git a/apax/train/loss.py b/apax/train/loss.py index 4a8f52e4..cd28786e 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -52,7 +52,7 @@ def force_angle_exponential_weight( return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor -def stress_tril(label, prediction, divisor= 1.0): +def stress_tril(label, prediction, divisor=1.0): idxs = jnp.tril_indices(3) label_tril = label[:, idxs[0], idxs[1]] prediction_tril = prediction[:, idxs[0], idxs[1]]