diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index 4e69c4ee..ccbbf1cf 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -236,11 +236,11 @@ def train( loss.backward() # Gradient Clipping. - # for p in gflownet.parameters(): - # if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad. - # p.grad.data.clamp_( - # -gradient_clip_value, gradient_clip_value - # ).nan_to_num_(0.0) + for p in gflownet.parameters(): + if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad. + p.grad.data.clamp_( + -gradient_clip_value, gradient_clip_value + ).nan_to_num_(0.0) optimizer.step() states_visited += len(trajectories)