diff --git a/thunder/__init__.py b/thunder/__init__.py
index 5f9bd9f52..5f4b75d6e 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -442,6 +442,11 @@ def get_computation_and_inputs(*args, **kwargs):
         # which seems to break the consistency of cache_info, leading to a failure in cache_info check.
         cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)
 
+        # Store the `is_grad_enabled` state of PyTorch. This is used by the vjp transform
+        # to treat certain Symbols as constant.
+        cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
+        cd.is_grad_enabled = pytorch.is_grad_enabled()
+
         # TODO RC1 Add module and function checks to prologue (make it a compile option)
 
         # Checks cache
diff --git a/thunder/common.py b/thunder/common.py
index bc5f37015..674cab65d 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -221,6 +221,10 @@ def __init__(
         # State for pytorch autocast context managers.
         self.autocast_stack: AutocastStack = AutocastStack()
 
+        # State to query whether grad is enabled or disabled using
+        # torch.no_grad/torch.enable_grad/torch._C._set_grad_enabled
+        self.is_grad_enabled: bool = True
+
         #
         # Gathers additional metadata
         #
diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py
index c34071c1c..da6eca6dd 100644
--- a/thunder/core/symbol.py
+++ b/thunder/core/symbol.py
@@ -21,7 +21,8 @@
 from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map
 import thunder.core.dtypes as dtypes
 import thunder.core.devices as devices
-from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy
+from thunder.core.proxies import Proxy, TensorProxy, NumberProxy, variableify, CollectionProxy, ProxyTag
+from thunder.core.compile_data import get_compile_data
 
 from thunder.core.trace import (
     get_tracectx,
@@ -320,6 +321,20 @@ def __call__(self, *args, **kwargs):
             result = self.meta(*args, **kwargs)
         trace.pop_scope()
 
+        cd = get_compile_data()
+        if cd is not None and not cd.is_grad_enabled:
+            # If grad is disabled using `torch.no_grad` or `torch._C._set_grad_enabled(False)`,
+            # tag the results with `DETACHED_AUTOGRAD_GRAPH`, which makes this Symbol a constant for
+            # the vjp transform (applied later).
+            def tag_tensorproxy_output_as_detached(proxy):
+                if isinstance(proxy, TensorProxy):
+                    # We need to remove the name from the trace, otherwise `replace` will return a proxy with a new name.
+                    trace.names.remove(proxy.name)
+                    return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))
+                return proxy
+
+            result = tree_map(tag_tensorproxy_output_as_detached, result)
+
         bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)
         symbols_list = trace.peek_scope()
 
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 2d9d88cdd..7b09ef26b 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -33,6 +33,7 @@
     variableify,
     unvariableify,
     FutureTensorProxy,
+    ProxyTag,
 )
 from thunder.core.compile_data import get_compile_data, get_compile_option
 from thunder.core.langctxs import langctx, Languages
@@ -2485,10 +2486,17 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
         bool: True if the symbol is constant, False otherwise.
     """
     are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
+    # Symbols tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
+    # These are treated as constant for VJP.
+    # NOTE - `any(()) is False`
+    output_disconnected_from_graph = any(
+        ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in symbol.flat_outs if isinstance(o, TensorProxy)
+    )
     return (
         are_all_args_non_differentiable
         or symbol.are_all_args_constant
         or symbol.sym.id in nondifferentiable_vjp_symbols
+        or output_disconnected_from_graph
     )
 
 
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index 4330b5236..29c491be0 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2099,7 +2099,7 @@ def func(x):
     compiled = executor.make_callable(func)
     out = compiled(x)
     assert out is x
-    initial_trace_with_dce = thunder.last_traces(compiled)[3]
+    initial_trace_with_dce = thunder.last_traces(compiled)[4]
     assert "Constructed by Dead Code Elimination" in str(initial_trace_with_dce)
     assert len(initial_trace_with_dce.bound_symbols) == 2
     assert initial_trace_with_dce.bound_symbols[0].sym.id == prims.PrimIDs.UNPACK_TRIVIAL
@@ -2480,27 +2480,81 @@ def foo_error(args):
 
 
 def test_grad_ctx():
+    # NOTE - This test would start failing if tags on Proxies are dropped,
+    #        as the computation under `no_grad` won't be treated as constant
+    #        and the grad won't match PyTorch eager.
+
+    # Test that `enable_grad` on a function works correctly
     @torch.enable_grad()
     def foo1(x):
         return x + 1
 
     x = torch.randn(3, 3, requires_grad=True)
-    with pytest.warns(UserWarning, match="have no effect under thunder.jit"):
-        thunder.jit(foo1)(x).sum().backward()
-
+    thunder.jit(foo1)(x).sum().backward()
     assert x.grad is not None
 
+    # Test that `no_grad` on a function works correctly
    @torch.no_grad()
     def foo2(x):
         return x + 1
 
     x = torch.randn(3, 3, requires_grad=True)
-    with pytest.warns(UserWarning, match="have no effect under thunder.jit"):
-        thunder.jit(foo2)(x).sum().backward()
+    thunder.jit(foo2)(x).sum().backward()
+    assert x.grad is None
 
-    # `torch.no_grad` has no effect on thunder's autodiff which determines whether to compute grad based on `requires_grad=True`.
-    # Thus when backward is called it computes grad for the input.
-    assert x.grad is not None
+    # Test that the `no_grad` ctx correctly disables gradient computation
+    def foo3(x):
+        with torch.no_grad():
+            y = x * 3
+        return x * 2 + y
+
+    x = torch.randn(3, 3, requires_grad=True)
+    with torch.no_grad():
+        x_ref = x.clone()
+        x_ref.requires_grad_(True)
+
+    foo3(x_ref).sum().backward()
+    thunder.jit(foo3)(x).sum().backward()
+    # Verify the gradients match
+    torch.testing.assert_close(x.grad, x_ref.grad)
+
+    # Test nested `no_grad` and `enable_grad`
+    def foo4(x):
+        with torch.enable_grad():
+            with torch.no_grad():
+                y = x * 3
+            z = x * 4
+        return x * 2 + y + z
+
+    x = torch.randn(3, 3, requires_grad=True)
+    with torch.no_grad():
+        x_ref = x.clone()
+        x_ref.requires_grad_(True)
+
+    foo4(x_ref).sum().backward()
+    thunder.jit(foo4)(x).sum().backward()
+    # Verify the gradients match
+    torch.testing.assert_close(x.grad, x_ref.grad)
+
+    def foo5(x):
+        return x * 2
+
+    x = torch.randn(3, 3, requires_grad=True)
+    with torch.no_grad():
+        x_ref = x.clone()
+        x_ref.requires_grad_(True)
+
+    jfoo = thunder.jit(foo5)
+    with torch.no_grad():
+        o = jfoo(x)
+        assert o.grad_fn is None
+    assert thunder.cache_misses(jfoo) == 1  # First compilation
+
+    # Running it outside of `torch.no_grad` should lead to a recompile.
+    foo5(x_ref).sum().backward()
+    jfoo(x).sum().backward()
+    torch.testing.assert_close(x.grad, x_ref.grad)
+    assert thunder.cache_misses(jfoo) == 2
 
 
 def test_serialize_trace():
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index b216d2a68..a94ada1cc 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -47,6 +47,7 @@
     ListProxy,
     DictProxy,
     numberproxy,
+    ProxyTag,
 )
 from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
 from thunder.core.symbol import Symbol
@@ -5238,11 +5239,24 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device:
 register_function(torch.device, torch_device)
 
 
-def _set_grad_enabled_with_warning(enabled: bool) -> None:
-    warnings.warn("torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect under thunder.jit")
+# Tag to use on Proxies created in `no_grad` regions.
+# The VJP transform will treat BoundSymbols whose outputs carry this tag
+# as constant.
+ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH")
 
 
-register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
+# This is just a marker Symbol. The `tag_no_grad_symbols_pass` pass uses these symbols
+# to find the `no_grad` regions and mark the BoundSymbols within them as constant
+# for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag.
+@torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,))
+def _set_grad_enabled_with_warning(enabled: bool) -> None:
+    cd = get_compile_data()
+    if cd is None:
+        warnings.warn(
+            "torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect, use thunder.jit for correct behaviour"
+        )
+        return
+    get_compile_data().is_grad_enabled = enabled
 
 
 def _unwrap_if_dead(tensor):
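
For context, a minimal usage sketch (not part of the diff) of the behaviour these changes enable, mirroring `test_grad_ctx` above. It assumes a lightning-thunder build with this patch applied; only `thunder.jit`, `torch.no_grad`, and `torch.testing.assert_close` are used.

```python
# Illustrative sketch only (not part of the diff): with this patch, computation
# inside a `torch.no_grad` block is tagged DETACHED_AUTOGRAD_GRAPH and treated as
# constant by the vjp transform, so gradients match PyTorch eager.
import torch
import thunder


def f(x):
    with torch.no_grad():
        y = x * 3  # constant for autograd under this patch
    return x * 2 + y


x = torch.randn(3, 3, requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)

f(x_ref).sum().backward()           # eager reference
thunder.jit(f)(x).sum().backward()  # thunder now respects the no_grad region

torch.testing.assert_close(x.grad, x_ref.grad)  # both grads are 2 * torch.ones(3, 3)
```

Because `is_grad_enabled` is also recorded in `cache_info`, calling the jitted function under `torch.no_grad` and then again with grad enabled triggers a recompile rather than reusing the cached trace, which is what the final `cache_misses` checks in `test_grad_ctx` verify.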