Commit
apply function with compile data
beverlylytle committed Nov 27, 2024
1 parent 590ef18 commit 3fb1f53
Showing 8 changed files with 13 additions and 15 deletions.
4 changes: 2 additions & 2 deletions thunder/__init__.py
@@ -824,8 +824,8 @@ def fn_(*args, **kwargs) -> Any:
         cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)

         check_storage_aliases(cache_entry, inps)

-        result = cache_entry.computation_fn(*inps)
+        with compile_data_and_stats(cd, cs):
+            result = cache_entry.computation_fn(*inps)
         result = maybe_connect_to_autograd(cache_entry, result)
         result = maybe_call_epilogue(cache_entry, result, pro_to_epi)
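The hunk above is the behavioural core of the commit: the computation function now runs inside the compile_data_and_stats(cd, cs) context manager, so code invoked during execution (for example executor implementations) can look up the active compile data with get_compile_data() instead of having flags such as is_grad_enabled threaded through as arguments. A minimal sketch of that pattern, using a simplified contextvar-based store rather than Thunder's actual helpers:

    # Simplified stand-ins for compile_data_and_stats / get_compile_data; the real
    # Thunder helpers live elsewhere, this only illustrates the ambient-data pattern.
    from contextlib import contextmanager
    from contextvars import ContextVar

    _COMPILE_DATA: ContextVar = ContextVar("compile_data", default=None)

    @contextmanager
    def compile_data_and_stats(cd, cs):
        token = _COMPILE_DATA.set((cd, cs))  # make (cd, cs) visible inside the block
        try:
            yield
        finally:
            _COMPILE_DATA.reset(token)

    def get_compile_data():
        entry = _COMPILE_DATA.get()
        return None if entry is None else entry[0]

    class FakeCompileData:          # hypothetical stand-in for Thunder's compile data
        is_grad_enabled = True

    def computation_fn():
        # a callee can query the flag without it being passed explicitly
        return get_compile_data().is_grad_enabled

    with compile_data_and_stats(FakeCompileData(), None):
        assert computation_fn() is True

The remaining hunks delete the explicit is_grad_enabled plumbing that this makes unnecessary.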
2 changes: 0 additions & 2 deletions thunder/core/prims.py
@@ -4030,8 +4030,6 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_
 def copy__meta(
     copy_from: TensorProxy,
     copy_to: TensorProxy,
-    *,
-    is_grad_enabled: bool = False,
 ):
     utils.check_type(copy_from, TensorProxy)
     utils.check_type(copy_to, TensorProxy)
2 changes: 1 addition & 1 deletion thunder/core/transforms.py
@@ -1588,7 +1588,7 @@ def zeros_like(x):
     prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)),
     prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)),
     prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)),
-    prims.PrimIDs.COPY_: lambda x, y, is_grad_enabled: (prims.copy_(x, y, is_grad_enabled=is_grad_enabled), tuple()),
+    prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()),
     prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()),
 }
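Each entry in this table appears to pair a primitive with a rule that returns the primal result plus the inputs saved for the backward pass, so dropping is_grad_enabled simply brings the COPY_ rule back to the prim's two-argument form. A toy model of that rule shape (names are illustrative, not Thunder's machinery):

    # Toy model of a forward-rule table: each rule returns
    # (primal_output, residuals_saved_for_backward). Purely illustrative.
    import math

    toy_rules = {
        "LOG2":  lambda x: (math.log2(x), (x,)),        # backward needs x
        "FMOD":  lambda x, y: (math.fmod(x, y), (x, y)),
        "COPY_": lambda src, dst: (dst, tuple()),        # nothing to save for a copy
    }

    out, saved = toy_rules["LOG2"](8.0)
    print(out, saved)   # 3.0 (8.0,)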
2 changes: 0 additions & 2 deletions thunder/executors/nvfuserex_impl.py
@@ -2054,7 +2054,6 @@ def var_mean(
 def _copy__check(
     copy_from: TensorProxy,
     copy_to: TensorProxy,
-    is_grad_enabled: bool,
 ) -> bool:
     return are_supported_tensors(copy_from, copy_to)

@@ -2065,7 +2064,6 @@ def copy_(
     *,
     fd: FusionDefinition,
     lc_to_nv_map: dict,
-    is_grad_enabled: bool,
 ) -> Any:
     nvcopy_from = getnv(copy_from, fd, lc_to_nv_map)
     nvcopy_to = getnv(copy_to, fd, lc_to_nv_map)
5 changes: 3 additions & 2 deletions thunder/executors/torchex.py
@@ -2181,8 +2181,9 @@ def is_float_type(self, input):
     einops._backends._type2backend[TensorProxy] = EinopsThunderBackend()


-def _copy__impl(copy_from, copy_to, *, is_grad_enabled):
-    if is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad:
+def _copy__impl(copy_from, copy_to):
+    cd = get_compile_data()
+    if cd.is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad:
         raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.")
     copy_to.copy_(copy_from)
     return copy_to
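With this change _copy__impl no longer receives is_grad_enabled as a keyword argument; it queries the active compile data via get_compile_data() at call time. The guard it enforces mirrors stock PyTorch behaviour for in-place operations on leaf tensors, shown below in plain PyTorch (illustration only, not Thunder code):

    import torch

    dst = torch.zeros(2, requires_grad=True)   # a leaf tensor that requires grad
    src = torch.ones(2)

    try:
        dst.copy_(src)                          # grad mode on: in-place op on a leaf is rejected
    except RuntimeError as err:
        print(err)  # a leaf Variable that requires grad is being used in an in-place operation.

    with torch.no_grad():
        dst.copy_(src)                          # with grad disabled, the copy succeeds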
3 changes: 2 additions & 1 deletion thunder/tests/test_core.py
@@ -350,7 +350,8 @@ def step(self):

     optimizer = Optimizer([a, b])
     cstep = executor.make_callable(optimizer.step)
-    cstep()
+    with torch.no_grad():
+        cstep()

     expected_a = ref_a - 0.1 * a.grad
     assert_close(a, expected_a)
7 changes: 4 additions & 3 deletions thunder/tests/test_inplace_functionalization.py
@@ -478,11 +478,11 @@ def f(xs, ys, z):
 def test_inplace_to_tensors_with_grad(executor, device, _):
     @torch.no_grad
     def add_grad(x, y):
-        x.add_(x.grad)
+        return x.add_(x.grad)

     @torch.no_grad
     def add_y(x, y):
-        x.add_(y, alpha=0.1)
+        return x.add_(y, alpha=0.1)

     for fn in (add_grad, add_y):
         jitted_f = executor.make_callable(fn)

@@ -494,7 +494,8 @@ def add_y(x, y):
         x_ref.grad = x.grad.clone().detach()
         y_ref = y.clone().detach()

-        res = jitted_f(x, y)
+        with torch.no_grad():
+            res = jitted_f(x, y)
         res_ref = fn(x_ref, y_ref)

         torch.testing.assert_close(x, x_ref)
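The test tweaks follow the same theme: the helpers now return the result of the in-place op, and the jitted call is wrapped in torch.no_grad() so that the in-place update of a leaf that requires grad is permitted. In plain PyTorch terms (an illustration of what the test relies on, not Thunder code):

    import torch

    x = torch.ones(3, requires_grad=True)     # leaf parameter with grad required
    y = torch.full((3,), 0.5)

    with torch.no_grad():                      # outside no_grad this in-place op would raise
        res = x.add_(y, alpha=0.1)

    assert res is x                            # add_ returns the mutated tensor itself
    torch.testing.assert_close(x, torch.full((3,), 1.05))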
3 changes: 1 addition & 2 deletions thunder/torch/__init__.py
@@ -1960,8 +1960,7 @@ def copysign_(a, b, /):

 @torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
 def copy_(a, b, /):
-    cd = get_compile_data()
-    return prims.copy_(b, a, is_grad_enabled=cd.is_grad_enabled if cd is not None else False)
+    return prims.copy_(b, a)


 # TODO Implement div
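One detail worth noting in the hunk above: copy_(a, b) follows torch.Tensor.copy_ semantics, where a.copy_(b) writes b into a, so the prim is invoked as prims.copy_(b, a) to match its (copy_from, copy_to) parameter order. A tiny stand-in to illustrate just the argument swap (hypothetical helper, not the real prim):

    def prim_copy_(copy_from, copy_to):         # hypothetical stand-in for prims.copy_
        copy_to[:] = copy_from                  # write the source into the destination
        return copy_to

    def tensor_copy_(a, b):
        # torch.Tensor.copy_ style: a.copy_(b) copies b into a,
        # hence the (b, a) argument order when lowering to the prim.
        return prim_copy_(b, a)

    dst, src = [0, 0], [1, 2]
    tensor_copy_(dst, src)
    print(dst)   # [1, 2]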
