Skip to content

Commit

Permalink
attribute access to subclass proxy seems functioning
Browse files Browse the repository at this point in the history
somehow, apparently

Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 3, 2024
1 parent 8b4f7fd commit 0b69d52
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

import thunder
from thunder.tests.framework import instantiate, NOTHING
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
Expand Down Expand Up @@ -87,20 +88,35 @@ def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any):
return ScaleTensorSubclass(out, scales[0])


def test_func_of_subclass_ctor_wrapper():
@instantiate(
dtypes=(thunder.core.dtypes.float32,),
)
def test_func_of_subclass_ctor_wrapper(executor, device, _):

def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
return ScaleTensorSubclass(x, scale)
y = ScaleTensorSubclass(x, scale)
return y

jitted = executor.make_callable(f)

device = torch.device("cuda")
dtype = torch.float32
shape = (2, 2)
x = make_tensor(shape, device=device, dtype=dtype)
scale = make_tensor((), device=device, dtype=dtype)

jitted = thunder.jit(f)
expected = f(x, scale)
actual = jitted(x, scale)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))

def f(x: torch.Tensor, scale: torch.Tensor):
y = ScaleTensorSubclass(x, scale)
z = ScaleTensorSubclass(y._x, y._scale)
return z

jitted = executor.make_callable(f)

expected = f(x, scale)
actual = jitted(x, scale)
assert type(expected) is type(actual)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
print(thunder.last_traces(jitted)[0])
print(thunder.last_traces(jitted)[-1])

0 comments on commit 0b69d52

Please sign in to comment.