
Commit

update impl to use compiledata
kshitij12345 committed Nov 13, 2024
1 parent 274f284 commit 91859f2
Showing 5 changed files with 17 additions and 55 deletions.
7 changes: 3 additions & 4 deletions thunder/__init__.py
@@ -35,7 +35,6 @@
     wrap_return_value_together_with_argments,
     unwrap_return_value,
     remove_context_manager_prims_from_trace,
-    tag_no_grad_symbols_pass,
 )
 from thunder.core.functionalization import (
     check_inplace_to_views,
@@ -443,6 +442,9 @@ def get_computation_and_inputs(*args, **kwargs):
     # which seems to break the consistency of cache_info, leading to a failure in cache_info check.
     cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)
 
+    cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
+    cd.is_grad_enabled = pytorch.is_grad_enabled()
+
     # TODO RC1 Add module and function checks to prologue (make it a compile option)
 
     # Checks cache
@@ -539,9 +541,6 @@ def get_computation_and_inputs(*args, **kwargs):
     computation_trc = wrap_return_value_together_with_argments(computation_trc)
     computation_traces.append(computation_trc)
 
-    computation_trc = tag_no_grad_symbols_pass(computation_trc)
-    computation_traces.append(computation_trc)
-
     computation_trc = remove_context_manager_prims_from_trace(computation_trc)
     computation_traces.append(computation_trc)
 
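Taken together, the thunder/__init__.py change records the caller's grad mode in two places: in cache_info, so that a computation traced under torch.no_grad() is not silently reused when gradients are enabled, and on the CompileData object (cd), so that later stages of tracing can consult it. A minimal sketch of that idea, assuming a simplified stand-in for CompileData and a hypothetical record_grad_mode helper (neither is Thunder's real API surface):

    import torch

    class CompileData:
        # Simplified stand-in for thunder.common.CompileData; only the new field is shown.
        def __init__(self):
            self.is_grad_enabled: bool = True

    def record_grad_mode(cd: CompileData, cache_info: dict) -> None:
        # Hypothetical helper mirroring the two assignments added to
        # get_computation_and_inputs: the grad mode becomes part of the cache
        # metadata and is also stored on the compile data for tracing-time use.
        cache_info["is_grad_enabled"] = torch.is_grad_enabled()
        cd.is_grad_enabled = torch.is_grad_enabled()

    cd, cache_info = CompileData(), {}
    with torch.no_grad():
        record_grad_mode(cd, cache_info)
    assert cache_info["is_grad_enabled"] is False and cd.is_grad_enabled is False
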
2 changes: 2 additions & 0 deletions thunder/common.py
@@ -221,6 +221,8 @@ def __init__(
     # State for pytorch autocast context managers.
     self.autocast_stack: AutocastStack = AutocastStack()
 
+    self.is_grad_enabled: bool = True
+
     #
     # Gathers additional metadata
     #
12 changes: 11 additions & 1 deletion thunder/core/symbol.py
@@ -21,7 +21,8 @@
 from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map
 import thunder.core.dtypes as dtypes
 import thunder.core.devices as devices
-from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy
+from thunder.core.proxies import Proxy, TensorProxy, NumberProxy, variableify, CollectionProxy, ProxyTag
+from thunder.core.compile_data import get_compile_data
 
 from thunder.core.trace import (
     get_tracectx,
@@ -320,6 +321,15 @@ def __call__(self, *args, **kwargs):
     result = self.meta(*args, **kwargs)
     trace.pop_scope()
 
+    if not get_compile_data().is_grad_enabled:
+
+        def tag_tensorproxy_output_as_detached(proxy):
+            if isinstance(proxy, TensorProxy):
+                return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))
+            return proxy
+
+        result = tree_map(tag_tensorproxy_output_as_detached, result)
+
     bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)
     symbols_list = trace.peek_scope()
 
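The new block in Symbol.__call__ tags every TensorProxy produced while grad is disabled, so the VJP transform later treats those outputs as constants. Below is a minimal, self-contained sketch of that tree_map tagging pattern; TensorProxy, ProxyTag, and tree_map here are toy stand-ins, not Thunder's real classes:

    from dataclasses import dataclass, replace as dc_replace
    from enum import Enum, auto

    class ProxyTag(Enum):
        # Toy stand-in for thunder.core.proxies.ProxyTag
        DETACHED_AUTOGRAD_GRAPH = auto()

    @dataclass(frozen=True)
    class TensorProxy:
        # Toy stand-in; real proxies also carry shape, dtype, device, etc.
        name: str
        tags: tuple = ()

        def replace(self, tags):
            return dc_replace(self, tags=tags)

    def tree_map(fn, tree):
        # Minimal pytree map over tuples/lists/dicts, enough for this example.
        if isinstance(tree, (tuple, list)):
            return type(tree)(tree_map(fn, t) for t in tree)
        if isinstance(tree, dict):
            return {k: tree_map(fn, v) for k, v in tree.items()}
        return fn(tree)

    def tag_as_detached(proxy):
        if isinstance(proxy, TensorProxy):
            return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))
        return proxy

    result = (TensorProxy("t0"), 3.14, [TensorProxy("t1")])
    tagged = tree_map(tag_as_detached, result)
    assert ProxyTag.DETACHED_AUTOGRAD_GRAPH in tagged[0].tags
    assert tagged[1] == 3.14  # non-tensor leaves pass through untouched

Doing this eagerly at bind time takes over from the standalone tag_no_grad_symbols_pass trace pass deleted below, which had to re-walk the finished trace and manage proxy names under a tracectx.
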
49 changes: 0 additions & 49 deletions thunder/core/transform_common.py
@@ -492,52 +492,3 @@ def is_context_manager_prim(bsym):
     new_trace.bound_symbols = filtered_bsyms
     new_trace.set_provenance(TraceProvenance("Remove context manager prims"))
     return new_trace
-
-
-def tag_no_grad_symbols_pass(trace: Trace) -> Trace:
-    """
-    This function iterates over trace and marks the BoundSymbols
-    in `no_grad` regions such that VJP pass will treat them as constant.
-    """
-    is_no_grad_region = False
-
-    # NOTE - This will also copy name from original trace.
-    new_trace = from_trace(trace)
-    new_bsyms = []
-
-    for bsym in trace.bound_symbols:
-        # case - torch._C._set_grad_enabled(False)
-        if bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and not bsym.args[0]:
-            is_no_grad_region = True
-            continue
-        # case - torch._C._set_grad_enabled(True)
-        elif bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and bsym.args[0]:
-            is_no_grad_region = False
-            continue
-
-        if is_no_grad_region:
-            # Mark the TensorProxy output of the `bsym`
-            # with `ProxyTag.DETACHED_AUTOGRAD_GRAPH` so that
-            # vjp will treat this as constant.
-
-            def create_detached_output(t):
-                if isinstance(t, TensorProxy):
-                    # NOTE - We need `tracectx` as creating/replacing name for proxy
-                    # tries a look-up in current trace.
-                    with tracectx(new_trace):
-                        # Remove the name so that we can re-use it.
-                        # Otherwise, we get a proxy with new name.
-                        new_trace.names.remove(t.name)
-                        return t.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))
-
-                return t
-
-            new_output = tree_map(create_detached_output, bsym.output)
-            # Create a copy of the `bsym` with `new_output`
-            bsym = bsym.from_bsym(output=new_output)
-
-        new_bsyms.append(bsym)
-
-    new_trace.bound_symbols = new_bsyms
-    new_trace.set_provenance(TraceProvenance("no_grad detach graph for vjp pass"))
-    return new_trace
2 changes: 1 addition & 1 deletion thunder/torch/__init__.py
@@ -5238,7 +5238,7 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device:
 # for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag.
 @torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,))
 def _set_grad_enabled_with_warning(enabled: bool) -> None:
-    pass
+    get_compile_data().is_grad_enabled = enabled
 
 
 def _unwrap_if_dead(tensor):
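With _set_grad_enabled_with_warning now writing into the compile data, a torch.no_grad() block inside a jitted function flips cd.is_grad_enabled during tracing, and Symbol.__call__ tags everything produced inside the block as DETACHED_AUTOGRAD_GRAPH. A hedged usage sketch (assuming the standard thunder.jit entry point; exact behavior depends on the Thunder version this commit targets):

    import torch
    import thunder

    def fn(x, w):
        with torch.no_grad():
            # Traced with is_grad_enabled=False: outputs here are expected to be
            # tagged DETACHED_AUTOGRAD_GRAPH and treated as constants by the VJP pass.
            scale = w.mean()
        return x * scale

    jfn = thunder.jit(fn)
    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)
    jfn(x, w).sum().backward()
    print(x.grad)  # gradient flows to x as usual
    print(w.grad)  # no gradient is expected to flow through the no_grad region
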
