support no_grad in thunder.jit #1423

Merged · 13 commits · Nov 18, 2024

Conversation

@kshitij12345 (Collaborator) commented on Nov 11, 2024:

Related #1420

In this PR, we support no_grad (and torch._C._set_grad_enabled) in user functions passed to thunder.jit.

Approach:

There are three steps (a minimal sketch of the mechanism follows the list):

  1. We add torch._C._set_grad_enabled as a marker symbol that updates the is_grad_enabled state on CompileData. These marker symbols are later removed by the remove_context_manager_prims_from_trace pass.
  2. When a Symbol is called during tracing, it queries is_grad_enabled on CompileData and tags its output Proxy accordingly.
  3. We update the VJP transform to treat BoundSymbols whose output carries the DETACHED_AUTOGRAD_GRAPH tag as constants. (We already have logic to treat Symbols as constants for VJP; we just add a new case here.)
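To make the mechanism concrete, below is a minimal, self-contained sketch of the three steps. GradState, FakeProxy, trace_symbol, and no_grad_marker are illustrative stand-ins (for CompileData, the output Proxy, symbol dispatch, and the marker symbol); they are not thunder's actual internals.

from contextlib import contextmanager
from dataclasses import dataclass, field

DETACHED_AUTOGRAD_GRAPH = "detached_autograd_graph"

@dataclass
class GradState:                       # stand-in for CompileData.is_grad_enabled
    is_grad_enabled: bool = True

@dataclass
class FakeProxy:                       # stand-in for a traced output Proxy
    name: str
    tags: set = field(default_factory=set)

state = GradState()

@contextmanager
def no_grad_marker():
    # Step 1: the marker symbol flips the recorded grad state while tracing the block.
    prev = state.is_grad_enabled
    state.is_grad_enabled = False
    try:
        yield
    finally:
        state.is_grad_enabled = prev

def trace_symbol(name: str) -> FakeProxy:
    # Step 2: each traced symbol tags its output when grad is disabled.
    out = FakeProxy(name)
    if not state.is_grad_enabled:
        out.tags.add(DETACHED_AUTOGRAD_GRAPH)
    return out

def is_constant_for_vjp(proxy: FakeProxy) -> bool:
    # Step 3: the VJP transform treats tagged outputs as constants (no gradient flows).
    return DETACHED_AUTOGRAD_GRAPH in proxy.tags

with no_grad_marker():
    t1 = trace_symbol("mul")           # tagged -> treated as constant by VJP
t2 = trace_symbol("add")               # untagged -> differentiated normally
assert is_constant_for_vjp(t1) and not is_constant_for_vjp(t2)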
Example and Forward-Backward Traces
import torch
import thunder

def fn(x):
    with torch.no_grad():
        y = x * torch.ones(3,)
    return x + y

x = torch.ones(3, requires_grad=True)
jfn = thunder.jit(fn)

o = jfn(x)
o.sum().backward()
print(x.grad)

# Check against PyTorch eager
# x = torch.ones(3, requires_grad=True)
# o = fn(x)
# o.sum().backward()
# print(x.grad)
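# Expected output in both cases: tensor([1., 1., 1.])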

First Trace

def computation(x):
  # x: "cpu f32[3]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187:             torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(False)

  # /home/kkalambarkar/lightning-thunder/scratchpad/test_no_grad.py:6:          y = x * torch.ones(3,)
  t0 = ltorch.ones(3, device=None, dtype=None)  # t0: "cpu f32[3]"
    # t0 = ltorch.full((3,), 1, device=None, dtype=torch.float32)  # t0: "cpu f32[3]"
      # t0 = prims.full((3,), 1, device=devices.Device("cpu"), dtype=dtypes.float32)  # t0: "cpu f32[3]"
  t1 = ltorch.mul(x, t0)  # t1: "cpu f32[3]"
    # t1 = prims.mul(x, t0)  # t1: "cpu f32[3]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187:             torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(True)

  # /home/kkalambarkar/lightning-thunder/scratchpad/test_no_grad.py:7:      return x + y
  t2 = ltorch.add(x, t1, alpha=1)  # t2: "cpu f32[3]"
    # t2 = prims.add(x, t1)  # t2: "cpu f32[3]"
  return t2

Execution Trace

def computation(x):
  # x: "cpu f32[3]"
  t0 = torch.ones(3, device=None, dtype=None)  # t0: "cpu f32[3]"
    # t0 = ltorch.ones(3, device=None, dtype=None)  # t0: "cpu f32[3]"
      # t0 = ltorch.full((3,), 1, device=None, dtype=torch.float32)  # t0: "cpu f32[3]"
        # t0 = prims.full((3,), 1, device=devices.Device("cpu"), dtype=dtypes.float32)  # t0: "cpu f32[3]"
  t1 = torch.mul(x, t0)  # t1: "cpu f32[3]"
    # t1 = ltorch.mul(x, t0)  # t1: "cpu f32[3]"
      # t1 = prims.mul(x, t0)  # t1: "cpu f32[3]"
  del t0
  t2 = torch.add(x, t1, alpha=1)  # t2: "cpu f32[3]"
    # t2 = ltorch.add(x, t1, alpha=1)  # t2: "cpu f32[3]"
      # t2 = prims.add(x, t1)  # t2: "cpu f32[3]"
  del t1
  return {'output': t2, 'flat_args': [x], 'flat_output': (t2,)}, ((), ())

Backward Execution Trace

def backward_fn(saved_for_backward, cotangents):
  # cotangents: "Collection"
  t6, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  return (t6,)
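
Note that nothing is saved for backward (the empty ((), ()) at the end of the execution trace), and the backward trace simply passes the incoming cotangent through to x: y is produced under no_grad, so only the x + y addition contributes to the gradient and x.grad comes out as all ones, matching PyTorch eager.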

@kshitij12345 (Collaborator, Author) commented:

@IvanYashchuk do you think this approach looks good or do you have anything else in mind?

@kshitij12345 marked this pull request as ready for review on November 13, 2024 15:54
@@ -442,6 +442,11 @@ def get_computation_and_inputs(*args, **kwargs):
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

# Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
# to treat certain Symbols as constant.
cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
Collaborator:

Is this cache_info entry needed?

Collaborator:

Well if you call with grad enabled and then without, you would want to have a cache miss?

Collaborator:

If we think we don't need it, let's remove it in a follow-up.

Collaborator (Author):

It is for the following cases:

Example 1

jfn = thunder.jit(fn)
with torch.no_grad():
    jfn(x)   # This will be compiled with no_grad

jfn(x)   # We want this to be recompiled.

Example 2

jfn = thunder.jit(fn)

jfn(x)

with torch.no_grad():
    jfn(x)   # We want this to be recompiled

Collaborator:

> Well if you call with grad enabled and then without, you would want to have a cache miss?

Of course.

How does this work if the content of cache_info["is_grad_enabled"] is not checked anywhere in this pull request?

Collaborator (Author):

In jit_ext.py, we read the values from cache_info and add the corresponding checks in the prologue:

cache_info = thunder._get_cache_info()
# assert len of cache info to ensure that we're not missing anything?
if cache_info:
    cache_info_p = Proxy(name="cache_info")
    bsym = prims.unpack_cache_info.bind(cache_info_p, output=cache_info_p)
    prologue_trace.bound_symbols.append(bsym)
    for k, v in cache_info.items():
        p = proxy(v, name=f"cache_info_{k}", history=None)
        bsym = prims.unpack_getitem.bind(cache_info_p, k, output=p)
        prologue_trace.bound_symbols.append(bsym)
        if isinstance(v, str):
            clang.check_string_value(p, v)
        elif isinstance(v, (int, bool, float)):
            clang.check_number_type_and_value(p, v)
        elif isinstance(v, (torch.dtype, torch.device)):
            clang.check_literal_like(p, v)
        else:
            raise NotImplementedError(f"cache info of type {type(v).__name__}")
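
Since is_grad_enabled is stored as a bool, it goes through the check_number_type_and_value branch above, so the prologue verifies the recorded grad mode on every call; calling under a different grad mode fails that check, misses the cache, and triggers recompilation.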

Collaborator:

Thank you!

@t-vi (Collaborator) left a review:

Looks great, thank you @kshitij12345 @IvanYashchuk

@t-vi enabled auto-merge (squash) on November 18, 2024 09:41
@t-vi merged commit 11a32a4 into Lightning-AI:main on Nov 18, 2024
41 checks passed
@kshitij12345 deleted the support-no-grad branch on November 18, 2024 10:18