diff --git a/flamby/benchmarks/benchmark_utils.py b/flamby/benchmarks/benchmark_utils.py index f9e46f2d8..a7498a576 100644 --- a/flamby/benchmarks/benchmark_utils.py +++ b/flamby/benchmarks/benchmark_utils.py @@ -438,7 +438,8 @@ def train_single_centric( grad_norm = 0 for param in model.parameters(): - grad_norm += torch.linalg.norm(param.grad) + if param.grad is not None: + grad_norm += torch.linalg.norm(param.grad) grad_norm_history.append(grad_norm) return model