Adding runtime warning for checkpointing inputs to have requires_grad=True (pytorch#6883)

* Adding the warning for the checkpointing inputs to have requires_grad=True

* fix bug
prigoyal authored and soumith committed Apr 24, 2018
1 parent 9765bb5 commit 7d32f6f
Showing 1 changed file with 17 additions and 0 deletions.
torch/utils/checkpoint.py: 17 additions, 0 deletions
@@ -1,4 +1,5 @@
 import torch
+import warnings
 
 
 def detach_variable(inputs):
@@ -14,10 +15,16 @@ def detach_variable(inputs):
             "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
 
 
+def check_backward_validity(inputs):
+    if not any(inp.requires_grad for inp in inputs):
+        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
+
+
 class CheckpointFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, run_function, *args):
+        check_backward_validity(args)
         ctx.run_function = run_function
         ctx.save_for_backward(*args)
         with torch.no_grad():
@@ -66,6 +73,11 @@ def checkpoint(function, *args):
     checkpointed version won't be equivalent, and unfortunately it can't be
     detected.
 
+    .. warning:
+        At least one of the inputs needs to have :code:`requires_grad=True` if
+        grads are needed for model inputs, otherwise the checkpointed part of the
+        model won't have gradients.
+
     Args:
         function: describes what to run in the forward pass of the model or
             part of the model. It should also know how to handle the inputs
@@ -96,6 +108,11 @@ def checkpoint_sequential(functions, segments, *inputs):
     Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
     with :func:`torch.autograd.backward`.
 
+    .. warning:
+        At least one of the inputs needs to have :code:`requires_grad=True` if
+        grads are needed for model inputs, otherwise the checkpointed part of the
+        model won't have gradients.
+
     Args:
         functions: A :class:`torch.nn.Sequential` or the list of modules or
             functions (comprising the model) to run sequentially.
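
For context: the new check_backward_validity call warns whenever none of the tensors passed to the checkpointed function has requires_grad=True, because in that case the checkpointed segment is detached from autograd and backward produces no gradients for it. A minimal usage sketch, not part of the commit (the Linear module and shapes are purely illustrative):

import torch
from torch.utils.checkpoint import checkpoint

model = torch.nn.Linear(10, 10)  # illustrative module; any callable segment works

# No input requires grad: the forward pass still runs, but the new check emits
# "None of the inputs have requires_grad=True. Gradients will be None" and the
# checkpointed segment stays disconnected from the autograd graph.
x = torch.randn(2, 10)
y = checkpoint(model, x)

# With at least one input marked requires_grad=True, the warning goes away and
# gradients flow back through the recomputed segment.
x = torch.randn(2, 10, requires_grad=True)
y = checkpoint(model, x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 10])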
