diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index 6459fdec21..de040f5208 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -4,6 +4,7 @@ import torch import thunder +from thunder.tests.framework import instantiate, NOTHING from thunder.tests.make_tensor import make_tensor if TYPE_CHECKING: @@ -87,20 +88,35 @@ def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any): return ScaleTensorSubclass(out, scales[0]) -def test_func_of_subclass_ctor_wrapper(): +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_of_subclass_ctor_wrapper(executor, device, _): def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: - return ScaleTensorSubclass(x, scale) + y = ScaleTensorSubclass(x, scale) + return y + + jitted = executor.make_callable(f) - device = torch.device("cuda") dtype = torch.float32 shape = (2, 2) x = make_tensor(shape, device=device, dtype=dtype) scale = make_tensor((), device=device, dtype=dtype) - jitted = thunder.jit(f) + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + def f(x: torch.Tensor, scale: torch.Tensor): + y = ScaleTensorSubclass(x, scale) + z = ScaleTensorSubclass(y._x, y._scale) + return z + + jitted = executor.make_callable(f) expected = f(x, scale) actual = jitted(x, scale) - assert type(expected) is type(actual) torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + print(thunder.last_traces(jitted)[0]) + print(thunder.last_traces(jitted)[-1])