support no_grad in thunder.jit #1423
Conversation
@IvanYashchuk do you think this approach looks good or do you have anything else in mind?
Co-authored-by: Ivan Yashchuk <[email protected]>
…hunder into support-no-grad
@@ -442,6 +442,11 @@ def get_computation_and_inputs(*args, **kwargs):
        # which seems to break the consistency of cache_info, leading to a failure in cache_info check.
        cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

        # Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
        # to treat certain Symbols as constant.
        cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
Is this cache_info entry needed?
Well if you call with grad enabled and then without, you would want to have a cache miss?
If we think we don't need it, let's remove it in a follow-up.
It is for the following cases:

Example 1

jfn = thunder.jit(fn)
with torch.no_grad():
    jfn(x)  # This will be compiled with no_grad
jfn(x)      # We want this to be recompiled.

Example 2

jfn = thunder.jit(fn)
jfn(x)
with torch.no_grad():
    jfn(x)  # We want this to be recompiled.
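A hypothetical sketch of why keying the cache on this flag gives the desired behavior in both examples; cached_compile and the dict-based _cache below are illustrative stand-ins, not Thunder's actual caching code.

import torch

_cache = {}  # hypothetical cache keyed on the grad-enabled flag

def cached_compile(fn):
    def wrapper(x):
        key = torch.is_grad_enabled()
        if key not in _cache:
            print(f"compiling for is_grad_enabled={key}")  # cache miss
            _cache[key] = fn  # stand-in for the real compilation step
        return _cache[key](x)
    return wrapper

jfn = cached_compile(lambda x: x * 2)
x = torch.randn(3)
with torch.no_grad():
    jfn(x)  # compiles the no_grad variant
jfn(x)      # grad mode changed, so this compiles again (Example 1)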
"Well if you call with grad enabled and then without, you would want to have a cache miss?"

Of course.

How does this work if the content of cache_info["is_grad_enabled"] is not checked anywhere in this pull request?
In jit_ext.py, we read the values from cache_info and add the corresponding checks in the prologue.
lightning-thunder/thunder/core/jit_ext.py
Lines 1580 to 1598 in 11a32a4
cache_info = thunder._get_cache_info()
# assert len of cache info to ensure that we're not missing anything?
if cache_info:
    cache_info_p = Proxy(name="cache_info")
    bsym = prims.unpack_cache_info.bind(cache_info_p, output=cache_info_p)
    prologue_trace.bound_symbols.append(bsym)
    for k, v in cache_info.items():
        p = proxy(v, name=f"cache_info_{k}", history=None)
        bsym = prims.unpack_getitem.bind(cache_info_p, k, output=p)
        prologue_trace.bound_symbols.append(bsym)
        if isinstance(v, str):
            clang.check_string_value(p, v)
        elif isinstance(v, (int, bool, float)):
            clang.check_number_type_and_value(p, v)
        elif isinstance(v, (torch.dtype, torch.device)):
            clang.check_literal_like(p, v)
        else:
            raise NotImplementedError(f"cache info of type {type(v).__name__}")
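Conceptually, the guard produced for the is_grad_enabled entry behaves like the following plain-Python sketch (hypothetical code, not the generated prologue itself): the cached computation is only reused when the current value still matches the boolean captured at compile time.

def prologue_guard(cache_info, compiled_with_grad_enabled):
    # Hypothetical stand-in for check_number_type_and_value applied to
    # cache_info["is_grad_enabled"]: the type and value must match what was
    # seen when this computation was compiled, otherwise it is a cache miss.
    v = cache_info["is_grad_enabled"]
    if not isinstance(v, bool) or v != compiled_with_grad_enabled:
        raise RuntimeError("cache miss: is_grad_enabled changed since compilation")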
Thank you!
Looks great, thank you @kshitij12345 @IvanYashchuk
Related #1420
In this PR, we support no_grad (and torch._C._set_grad_enabled) in user functions passed to thunder.jit.

Approach:
There are 3 steps.
1. Interpret torch._C._set_grad_enabled as a marker symbol which updates the state is_grad_enabled in CompileData. These symbols are removed in the remove_context_manager_prims_from_trace pass.
2. While tracing, read is_grad_enabled on CompileData and accordingly tag the output Proxy with DETACHED_AUTOGRAD_GRAPH.
3. The vjp transform treats Symbols whose outputs are tagged with DETACHED_AUTOGRAD_GRAPH as constants. (We already have logic to treat Symbols as constant for VJP; we just add a new case here.)

Example and Forward-Backward Traces
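As a usage sketch of the behavior this PR enables (an assumed example; the function, shapes, and printed gradient below are illustrative and not taken from the attached traces): operations inside a torch.no_grad() block in the jitted function are tagged and treated as constants by the vjp transform, so they do not contribute gradients.

import torch
import thunder

def fn(x):
    with torch.no_grad():
        y = x * 2  # traced, but its output is treated as constant for the backward pass
    return y + x

jfn = thunder.jit(fn)
x = torch.randn(3, requires_grad=True)
out = jfn(x)
out.sum().backward()
print(x.grad)  # all ones: only the `y + x` term contributes to the gradient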
First Trace
Execution Trace
Backward Execution Trace