diff --git a/olmo/train.py b/olmo/train.py index 105f82e40..4bcfe6a98 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -135,7 +135,8 @@ def cross_entropy_loss( z_squared = logits.logsumexp(-1).pow(2) if reduction == "mean": - z_squared = (z_squared * (labels != ignore_index)).mean() + mask = labels != ignore_index + z_squared = (z_squared * mask).sum() / mask.sum() elif reduction == "sum": z_squared = (z_squared * (labels != ignore_index)).sum()