Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise when in place operations occur on leafs requiring grad #1458

Merged
merged 19 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions thunder/core/functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from typing import TYPE_CHECKING

from thunder.core.compile_data import get_compile_data
import thunder.core.prims as prims
from thunder.core.proxies import variableify, TensorProxy, unvariableify, ProxyInterface
from thunder.core.pytree import tree_flatten, tree_unflatten
Expand Down Expand Up @@ -499,8 +500,12 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl
copy_from_for_new_copy = reshaped_copy_from
else:
copy_from_for_new_copy = copy_from
new_copy_return = prims.copy_.meta(copy_from_for_new_copy, new_copy_to)
new_copy_bsym = prims.copy_.bind(copy_from_for_new_copy, new_copy_to, output=new_copy_return)
cd = get_compile_data()
grad_enabled = cd.is_grad_enabled if cd is not None else False
new_copy_return = prims.copy_.meta(copy_from_for_new_copy, new_copy_to, grad_enabled=grad_enabled)
new_copy_bsym = prims.copy_.bind(
copy_from_for_new_copy, new_copy_to, grad_enabled=grad_enabled, output=new_copy_return
)
copy_bsyms.append(new_copy_bsym)
else:
var_copy_to = variableify(copy_to)
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -4030,6 +4030,8 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_
def copy__meta(
copy_from: TensorProxy,
copy_to: TensorProxy,
*,
grad_enabled: bool,
):
utils.check_type(copy_from, TensorProxy)
utils.check_type(copy_to, TensorProxy)
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,7 @@ def zeros_like(x):
prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)),
prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)),
prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)),
prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()),
prims.PrimIDs.COPY_: lambda x, y, grad_enabled: (prims.copy_(x, y, grad_enabled=grad_enabled), tuple()),
prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()),
}

Expand Down
3 changes: 3 additions & 0 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,6 +2093,8 @@ def var_mean(
def _copy__check(
copy_from: TensorProxy,
copy_to: TensorProxy,
*,
grad_enabled: bool,
) -> bool:
return are_supported_tensors(copy_from, copy_to)

Expand All @@ -2101,6 +2103,7 @@ def copy_(
copy_from: TensorProxy,
copy_to: TensorProxy,
*,
grad_enabled: bool,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:
Expand Down
24 changes: 10 additions & 14 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,24 @@
from __future__ import annotations
import operator
import importlib
from dataclasses import replace
from contextlib import ContextDecorator
from functools import wraps, partial
from inspect import signature
from itertools import groupby
from functools import partial, wraps
from numbers import Number
from typing import TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Hashable, Sequence
from collections.abc import Sequence
from types import ModuleType
from enum import Enum, auto

import torch
import math
from looseversion import LooseVersion

from thunder.core.compile_data import get_compile_data
import thunder.core.dtypes as dtypes
from thunder.core.dtypes import to_torch_dtype, to_dtype
import thunder.core.devices as devices
from thunder.core.devices import to_torch_device, to_device
import thunder.core.prims as prims
from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, from_trace
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, variableify, pytype
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol, BoundSymbol
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, pytype
from thunder.core.symbol import Symbol
from thunder.distributed.prims import DistributedReduceOps
import thunder.distributed.prims as dist_prims
import thunder.core.utils as utils
Expand Down Expand Up @@ -2202,12 +2194,16 @@ def is_float_type(self, input):
einops._backends._type2backend[TensorProxy] = EinopsThunderBackend()


def _copy__impl(copy_from, copy_to):
def _copy__impl(copy_from, copy_to, grad_enabled):
if grad_enabled and copy_to.is_leaf and copy_to.requires_grad:
raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.")
IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved
copy_to.copy_(copy_from)
return copy_to


copy_ = ex.register_operator("copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl)
copy_ = ex.register_operator(
"copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl, module=torch.Tensor
)
_register_implementation(prims.copy_, copy_, checker=_always_executable)


Expand Down
1 change: 1 addition & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _init_group(self, group, params, grads):
params.append(p)
grads.append(p.grad)

@torch.no_grad
def step(self):
for group in self.param_groups:
params = []
Expand Down
33 changes: 23 additions & 10 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import thunder
import thunder.core.dtypes as datatypes
import thunder.torch as ttorch
from thunder.tests.framework import instantiate, nvFuserExecutor
from thunder.tests.framework import instantiate, nvFuserExecutor, TorchExecutor


@instantiate(dtypes=datatypes.all_dtypes - datatypes.float_8bit_dtypes)
Expand All @@ -20,7 +20,7 @@ def torch_foo(x, y):
def foo(x, y):
z = x + y
# NOTE: nvfuserex doesn't support `return z`, i.e. the copy_from argument
o = thunder.core.prims.copy_(z, x)
o = thunder.core.prims.copy_(z, x, grad_enabled=True)
return o

traced_nvfuser_foo = executor.make_callable(foo)
Expand Down Expand Up @@ -49,7 +49,7 @@ def torch_foo(x, y):
def foo(x, y):
z = x * y
z = z * x
o = thunder.core.prims.copy_(z, x)
o = thunder.core.prims.copy_(z, x, grad_enabled=True)
p = y * y
return p

Expand Down Expand Up @@ -120,25 +120,25 @@ def forward(self, x):
def test_inplace_copy_sanity_check(executor, device, dtype):
def func0(x, y):
z = x * y
x = thunder.core.prims.copy_(z, x)
x = thunder.core.prims.copy_(z, x, grad_enabled=True)
return x + y

def func1(x, y):
z = x * y
o1 = thunder.core.prims.copy_(z, x)
o2 = thunder.core.prims.copy_(y, x)
o1 = thunder.core.prims.copy_(z, x, grad_enabled=True)
o2 = thunder.core.prims.copy_(y, x, grad_enabled=True)
return x, o1, o2

def func2(x, y):
z = x * y
o1 = thunder.core.prims.copy_(z, x)
o2 = thunder.core.prims.copy_(x, y)
o1 = thunder.core.prims.copy_(z, x, grad_enabled=True)
o2 = thunder.core.prims.copy_(x, y, grad_enabled=True)
return y, o1, o2

def func3(x, y):
z = x * y
o1 = thunder.core.prims.copy_(z, x)
o2 = thunder.core.prims.copy_(o1, y)
o1 = thunder.core.prims.copy_(z, x, grad_enabled=True)
o2 = thunder.core.prims.copy_(o1, y, grad_enabled=True)
return y, o2

for foo in (func0, func1, func2, func3):
Expand Down Expand Up @@ -178,3 +178,16 @@ def func(T0):
assert_close(a_ref, a)
for o, o_ref in zip(o_thunder, o_eager):
assert_close(o, o_ref)


@instantiate(executors=(TorchExecutor,), dtypes=datatypes.float_math_dtypes)
def test_inplace_copy_of_leaf_requiring_grad_fails(executor, device, dtype):
def fn(x):
x.copy_(x)

jitted_fn = executor.make_callable(fn)

tdtype = ttorch.to_torch_dtype(dtype)
a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=True)
with pytest.raises(RuntimeError):
jitted_fn(a)
4 changes: 2 additions & 2 deletions thunder/tests/test_inplace_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,11 @@ def f(xs, ys, z):
def test_inplace_to_tensors_with_grad(executor, device, _):
@torch.no_grad
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
def add_y(x, y):
x.add_(y, alpha=0.1)
return x.add_(y, alpha=0.1)

@torch.no_grad
def add_grad(x, y):
x.add_(x.grad, alpha=0.1)
return x.add_(x.grad, alpha=0.1)

for f in (add_y, add_grad):
jitted_f = executor.make_callable(f)
Expand Down
Loading
Loading