
Commit

support no_grad in thunder.jit (#1423)
kshitij12345 authored Nov 18, 2024
1 parent d60f85c commit 11a32a4
Showing 6 changed files with 113 additions and 13 deletions.
5 changes: 5 additions & 0 deletions thunder/__init__.py
@@ -442,6 +442,11 @@ def get_computation_and_inputs(*args, **kwargs):
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

# Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
# to treat certain Symbols as constant.
cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
cd.is_grad_enabled = pytorch.is_grad_enabled()

# TODO RC1 Add module and function checks to prologue (make it a compile option)

# Checks cache
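Because the grad mode is now part of cache_info, a jitted callable compiles separately for grad-enabled and grad-disabled calls. Below is a minimal usage sketch (not part of the diff; it mirrors the new test_grad_ctx test further down):

import torch
import thunder

def f(x):
    return x * 2

jf = thunder.jit(f)
x = torch.randn(3, 3, requires_grad=True)

with torch.no_grad():
    out = jf(x)             # first compilation, traced with grad disabled
    assert out.grad_fn is None

jf(x).sum().backward()      # grad mode changed -> cache miss, recompile
assert thunder.cache_misses(jf) == 2
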
4 changes: 4 additions & 0 deletions thunder/common.py
@@ -221,6 +221,10 @@ def __init__(
# State for pytorch autocast context managers.
self.autocast_stack: AutocastStack = AutocastStack()

# State to query whether grad is enabled or disabled using
# torch.no_grad/torch.enable_grad/torch._C._set_grad_enabled
self.is_grad_enabled: bool = True

#
# Gathers additional metadata
#
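A hedged sketch of how trace-time code can consult the new flag; get_compile_data is the accessor imported elsewhere in this diff, while my_trace_time_rule is a hypothetical helper used only for illustration:

from thunder.core.compile_data import get_compile_data

def my_trace_time_rule():
    # Hypothetical helper: detect whether tracing is currently inside a
    # torch.no_grad()/torch._C._set_grad_enabled(False) region.
    cd = get_compile_data()
    grad_enabled = cd.is_grad_enabled if cd is not None else True
    if not grad_enabled:
        ...  # e.g. skip work that only matters for autograd
    return grad_enabled
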
17 changes: 16 additions & 1 deletion thunder/core/symbol.py
@@ -21,7 +21,8 @@
from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy
from thunder.core.proxies import Proxy, TensorProxy, NumberProxy, variableify, CollectionProxy, ProxyTag
from thunder.core.compile_data import get_compile_data

from thunder.core.trace import (
get_tracectx,
@@ -320,6 +321,20 @@ def __call__(self, *args, **kwargs):
result = self.meta(*args, **kwargs)
trace.pop_scope()

cd = get_compile_data()
if cd is not None and not cd.is_grad_enabled:
# If grad is disabled using `torch.no_grad` or `torch._C._set_grad_enabled(False)`,
# tag the results with `DETACHED_AUTOGRAD_GRAPH` which makes this Symbol a constant for
# vjp transform (applied later).
def tag_tensorproxy_output_as_detached(proxy):
if isinstance(proxy, TensorProxy):
# We need to remove the name from the trace, otherwise `replace` will return a proxy with a new name.
trace.names.remove(proxy.name)
return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))
return proxy

result = tree_map(tag_tensorproxy_output_as_detached, result)

bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)
symbols_list = trace.peek_scope()

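To see the tagging in action, one can inspect the proxies of an early trace after compiling a function with a no_grad region. This is a sketch under the assumption that the tags are still visible on the trace index chosen below (which may differ between thunder versions); the imports match the ones added in this file:

import torch
import thunder
from thunder.core.proxies import TensorProxy, ProxyTag

def f(x):
    with torch.no_grad():
        y = x * 3
    return x * 2 + y

jf = thunder.jit(f)
jf(torch.randn(3, 3, requires_grad=True))

# Report which bound symbols produced detached outputs.
trc = thunder.last_traces(jf)[0]
for bsym in trc.bound_symbols:
    detached = any(
        isinstance(o, TensorProxy) and ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags
        for o in bsym.flat_outs
    )
    print(bsym.sym.id, "detached:", detached)
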
8 changes: 8 additions & 0 deletions thunder/core/transforms.py
@@ -33,6 +33,7 @@
variableify,
unvariableify,
FutureTensorProxy,
ProxyTag,
)
from thunder.core.compile_data import get_compile_data, get_compile_option
from thunder.core.langctxs import langctx, Languages
@@ -2485,10 +2486,17 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
bool: True if the symbol is constant, False otherwise.
"""
are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
# Symbols tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
# These are treated as constant for VJP.
# NOTE - `any(()) is False`
output_disconnected_from_graph = any(
ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in symbol.flat_outs if isinstance(o, TensorProxy)
)
return (
are_all_args_non_differentiable
or symbol.are_all_args_constant
or symbol.sym.id in nondifferentiable_vjp_symbols
or output_disconnected_from_graph
)


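A quick standalone illustration of the NOTE above, in plain Python: when a symbol has no TensorProxy outputs, the generator passed to any() is empty, so output_disconnected_from_graph stays False and the new condition never marks such symbols as constant by accident.

tensor_outs = []                          # the symbol produced no tensors
tagged = any(True for _ in tensor_outs)   # any over an empty iterable
assert tagged is False                    # same as any(()) is False
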
72 changes: 63 additions & 9 deletions thunder/tests/test_core.py
@@ -2099,7 +2099,7 @@ def func(x):
compiled = executor.make_callable(func)
out = compiled(x)
assert out is x
initial_trace_with_dce = thunder.last_traces(compiled)[3]
initial_trace_with_dce = thunder.last_traces(compiled)[4]
assert "Constructed by Dead Code Elimination" in str(initial_trace_with_dce)
assert len(initial_trace_with_dce.bound_symbols) == 2
assert initial_trace_with_dce.bound_symbols[0].sym.id == prims.PrimIDs.UNPACK_TRIVIAL
@@ -2480,27 +2480,81 @@ def foo_error(args):


def test_grad_ctx():
# NOTE - This test would start failing if tags on Proxies are dropped
# as the computation under `no_grad` won't be treated as constant
# and grad won't match with PyTorch eager.

# Test `enable_grad` on a function works correctly
@torch.enable_grad()
def foo1(x):
return x + 1

x = torch.randn(3, 3, requires_grad=True)
with pytest.warns(UserWarning, match="have no effect under thunder.jit"):
thunder.jit(foo1)(x).sum().backward()

thunder.jit(foo1)(x).sum().backward()
assert x.grad is not None

# Test `no_grad` on a function works correctly
@torch.no_grad()
def foo2(x):
return x + 1

x = torch.randn(3, 3, requires_grad=True)
with pytest.warns(UserWarning, match="have no effect under thunder.jit"):
thunder.jit(foo2)(x).sum().backward()
thunder.jit(foo2)(x).sum().backward()
assert x.grad is None

# `torch.no_grad` has no effect on thunder's autodiff which determines whether to compute grad based on `requires_grad=True`.
# Thus when backward is called it computes grad for the input.
assert x.grad is not None
# Test that the `no_grad` ctx correctly disables gradient computation
def foo3(x):
with torch.no_grad():
y = x * 3
return x * 2 + y

x = torch.randn(3, 3, requires_grad=True)
with torch.no_grad():
x_ref = x.clone()
x_ref.requires_grad_(True)

foo3(x_ref).sum().backward()
thunder.jit(foo3)(x).sum().backward()
# Verify the gradients match
torch.testing.assert_close(x.grad, x_ref.grad)

# Test nested `no_grad` and `enable_grad`
def foo4(x):
with torch.enable_grad():
with torch.no_grad():
y = x * 3
z = x * 4
return x * 2 + y + z

x = torch.randn(3, 3, requires_grad=True)
with torch.no_grad():
x_ref = x.clone()
x_ref.requires_grad_(True)

foo4(x_ref).sum().backward()
thunder.jit(foo4)(x).sum().backward()
# Verify the gradients match
torch.testing.assert_close(x.grad, x_ref.grad)

def foo5(x):
return x * 2

x = torch.randn(3, 3, requires_grad=True)
with torch.no_grad():
x_ref = x.clone()
x_ref.requires_grad_(True)

jfoo = thunder.jit(foo5)
with torch.no_grad():
o = jfoo(x)
assert o.grad_fn is None
assert thunder.cache_misses(jfoo) == 1 # First compilation

# Running it outside `torch.no_grad` should lead to a recompile.
foo5(x_ref).sum().backward()
jfoo(x).sum().backward()
torch.testing.assert_close(x.grad, x_ref.grad)
assert thunder.cache_misses(jfoo) == 2


def test_serialize_trace():
20 changes: 17 additions & 3 deletions thunder/torch/__init__.py
@@ -47,6 +47,7 @@
ListProxy,
DictProxy,
numberproxy,
ProxyTag,
)
from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol
@@ -5238,11 +5239,24 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device:
register_function(torch.device, torch_device)


def _set_grad_enabled_with_warning(enabled: bool) -> None:
warnings.warn("torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect under thunder.jit")
# Tag to use on Proxies created in `no_grad` regions.
# VJP transform will treat BoundSymbols whose output has these tags
# as constant.
ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH")


register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
# This is just a marker Symbol. `tag_no_grad_symbols_pass` pass uses these symbols
# to find the `no_grad` regions and mark the BoundSymbols within them as constant
# for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag.
@torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,))
def _set_grad_enabled_with_warning(enabled: bool) -> None:
cd = get_compile_data()
if cd is None:
warnings.warn(
"torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect, use thunder.jit for correct behaviour"
)
return
get_compile_data().is_grad_enabled = enabled


def _unwrap_if_dead(tensor):
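For context (not part of the diff): torch.no_grad(), torch.enable_grad(), and torch.set_grad_enabled() all effectively funnel through torch._C._set_grad_enabled on enter and exit, so intercepting that single function is enough for thunder.jit to see the region boundaries. A hedged sketch of what the interpreter effectively records while tracing:

import torch

def f(x):
    with torch.no_grad():        # __enter__ -> torch._C._set_grad_enabled(False)
        y = x * 3                # traced with cd.is_grad_enabled == False
    # __exit__ -> torch._C._set_grad_enabled(True), restoring the outer mode
    return x * 2 + y             # traced with cd.is_grad_enabled == True again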

