From e76a98cd53dc09d8d4344826b74b706c2a2177de Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 2 Nov 2024 21:21:16 +0900 Subject: [PATCH 1/7] workaround for `.__init__` call on the output of `_make_wrapper_subclass` lookaside Signed-off-by: Masaki Kozuki --- thunder/core/jit_ext.py | 37 ++++++ thunder/core/proxies.py | 25 +++++ thunder/tests/nvfuser_repro.py | 28 +++++ thunder/tests/test_tensor_subclass.py | 155 ++++++++++++++++++++++++++ thunder/tests/test_torchao_float8.py | 28 +++++ 5 files changed, 273 insertions(+) create mode 100644 thunder/tests/nvfuser_repro.py create mode 100644 thunder/tests/test_tensor_subclass.py create mode 100644 thunder/tests/test_torchao_float8.py diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 7b1cf3eb87..cfbb8c8a9f 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -62,6 +62,7 @@ NumberProxy, StringProxy, TensorProxy, + SubclassTensorProxy, FutureTensorProxy, make_proxy_name, Variable, @@ -757,6 +758,42 @@ def grad_transform(*args, **kwargs): return forward_result +@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass) +def _make_wrapper_subclass( + cls: torch._C._TensorMeta, + size: Sequence[int], + strides: Sequence[int] | None = None, + storage_offset: int | None = None, + memory_format: torch.memory_format | None = None, + dtype: torch.dtype | None = None, + layout: torch.layout | None = torch.strided, + device: torch.device | None = None, + pin_memory: bool = False, + requires_grad: bool = False, + dispatch_sizes_strides_policy: str | None = None, + dispatch_device: bool = False, + dispatch_layout: bool = False, + _extra_dispatch_keys: torch.DispatchKeySet | None = None, + storage_size: int | None = None, +): + ucls = unwrap(cls) + usize = unwrap(size) + udtype = unwrap(dtype) + udevice = unwrap(device) + urequires_grad = unwrap(requires_grad) + + subclass = SubclassTensorProxy( + None, + shape=usize, + device=udevice, + dtype=udtype, + requires_grad=urequires_grad, + history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]), + subclass_type=ucls, + ) + return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance])) + + @register_general_jit_lookaside(torch.autocast.__enter__) def autocast_enter(autocast_obj): unwrap_autocast_obj = unwrap(autocast_obj) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 2f2eb1c665..5118caa3c3 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1880,6 +1880,31 @@ def real(self): return method(self) +class SubclassTensorProxy(TensorProxy): + _tensors: list[TensorProxy] + _non_tensors: list[Any] + _subclass_type: torch._C._TensorMeta + + def __init__(self, *args, **kwargs): + tensors = kwargs.pop("tensors", []) + non_tensors = kwargs.pop("non_tensors", []) + subclass_type = kwargs.pop("subclass_type", None) + if not hasattr(self, "_name"): + super().__init__(*args, **kwargs) + self._tensors = tensors + self._non_tensors = non_tensors + self._subclass_type = subclass_type + else: + from thunder.core.pytree import tree_flatten + + flat_args, spec = tree_flatten((args, kwargs)) + self._tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args)) + self._non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args)) + baseutils.check( + self.history is not None, lambda: f"SubclassTensorProxy {self._name} must have its `history` set" + ) + + class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface): def __init__( self, diff --git 
a/thunder/tests/nvfuser_repro.py b/thunder/tests/nvfuser_repro.py new file mode 100644 index 0000000000..f48708dc5e --- /dev/null +++ b/thunder/tests/nvfuser_repro.py @@ -0,0 +1,28 @@ +# CUDA devices: +# 0: NVIDIA RTX 6000 Ada Generation +# torch version: 2.6.0a0+git408fe41 +# cuda version: 12.6 +# nvfuser version: 0.2.11+gitaad7286 +import torch +from nvfuser import FusionDefinition, DataType + + +def nvfuser_fusion_id0(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[2, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0] + ) + T1 = fd.define_tensor( + shape=[2, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0] + ) + T2 = fd.ops.add(T0, T1) + fd.add_output(T2) + + +with FusionDefinition() as fd: + nvfuser_fusion_id0(fd) + +inputs = [ + torch.randn(4, dtype=torch.float32, device="cuda:0").as_strided((2, 2), (2, 1)), + torch.randn(4, dtype=torch.float32, device="cuda:0").as_strided((2, 2), (2, 1)), +] +fd.execute(inputs) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py new file mode 100644 index 0000000000..0481885044 --- /dev/null +++ b/thunder/tests/test_tensor_subclass.py @@ -0,0 +1,155 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +import torch +import torch.utils._pytree as pytree + +import thunder +from thunder.tests.make_tensor import make_tensor + +if TYPE_CHECKING: + from typing import Any + + +class ScaleTensorSubclass(torch.Tensor): + _x: torch.Tensor + _scale: torch.Tensor + __slots__ = ["_x", "_scale"] + + def __new__(cls, x: torch.Tensor, scale: torch.Tensor): + assert scale.numel() == 1, f"Invalid `scale`: {scale}" + dtype = x.dtype + device = x.device + self = torch.Tensor._make_wrapper_subclass( + cls, + x.size(), + dtype=dtype, + device=device, + # strides=x.stride(), + # storage_offset=x.storage_offset(), + # layout=x.layout, + # requires_grad=x.requires_grad, + ) + self._x = x + self._scale = scale + + return self + + # ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22 + __torch_function__ = torch._C._disabled_torch_function_impl + + def __repr__(self): + return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})" + + def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]: + return ["_x", "_scale"], {} + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, torch.Tensor], + metadata: dict[str, Any], + outer_size, + outer_stride, + ) -> ScaleTensorSubclass: + return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"]) + + @staticmethod + def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass: + scale = x.abs().max() + return ScaleTensorSubclass(x, scale) + + @classmethod + def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None): + + def allowed_subclass(typ): + return ( + issubclass(cls, typ) + or issubclass(torch._subclasses.FakeTensor, typ) + or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ) + ) + + def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any): + if isinstance(t, ScaleTensorSubclass): + if t.is_floating_point(): + return t._x * t._scale + else: + return t._x + return t + + if not all(allowed_subclass(t) for t in types): + raise NotImplementedError(f"Unsupported types are included: {types}") + + scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass)) + unwrapped_args, unwrapped_kwargs = 
pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs)) + out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs) + if not isinstance(out, torch.Tensor): + return out + else: + return ScaleTensorSubclass(out, scales[0]) + + +# Error message: +# unpack_fn = d.get(inst) +# if unpack_fn is None: +# > raise NotImplementedError(f"Unpacking from {inst} {provenance}") +# E NotImplementedError: Unpacking from LOOKASIDE ProvenanceRecord( +# E i1 = INPUT_FN() +# E i2 = LOAD_ATTR(i1, '__globals__') +# E i3 = BINARY_SUBSCR(i2, 'ScaleTensorSubclass') +# E i4 = LOOKASIDE(i3) +# E ) +# +# thunder/core/jit_ext.py:1503: NotImplementedError +# +# The above exception was the direct cause of the following exception: +# +# def test_subclass_ctor(): +# +# def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: +# return ScaleTensorSubclass(x, scale) +# +# device = torch.device("cuda") +# dtype = torch.float32 +# shape = (2, 2) +# x = make_tensor(shape, device=device, dtype=dtype) +# scale = make_tensor((), device=device, dtype=dtype) +# +# jitted = thunder.jit(f) +# +# expected = f(x, scale) +# > actual = jitted(x, scale) +# +# thunder/tests/test_tensor_subclass.py:104: +# _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +# thunder/__init__.py:768: in wrapped +# return fn(*args, **kwargs) +# thunder/__init__.py:818: in fn_ +# cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) +# thunder/__init__.py:750: in wrapped +# cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) +# thunder/core/langctxs.py:136: in _fn +# result = fn(*args, **kwargs) +# thunder/__init__.py:234: in cache_info_wrapper +# res = fn(*args, **kwargs) +# thunder/__init__.py:522: in get_computation_and_inputs +# jit_results: TraceResults = thunder_general_jit( +# thunder/core/jit_ext.py:1788: in thunder_general_jit +# pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs) +# thunder/core/jit_ext.py:1576: in unpack_inputs +# pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1])) +# thunder/core/jit_ext.py:1576: in +# pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1])) +def test_subclass_ctor(): + + def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: + return ScaleTensorSubclass(x, scale) + + device = torch.device("cuda") + dtype = torch.float32 + shape = (2, 2) + x = make_tensor(shape, device=device, dtype=dtype) + scale = make_tensor((), device=device, dtype=dtype) + + jitted = thunder.jit(f) + + expected = f(x, scale) + actual = jitted(x, scale) diff --git a/thunder/tests/test_torchao_float8.py b/thunder/tests/test_torchao_float8.py new file mode 100644 index 0000000000..59e4d2909c --- /dev/null +++ b/thunder/tests/test_torchao_float8.py @@ -0,0 +1,28 @@ +import pytest + +pytest.importorskip("torchao") + +import torch +from torchao.float8 import convert_to_float8_training + +import thunder + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9), + reason="Requires cuda of 8.9 or higher", +) +def test_float8_linear(): + model: torch.nn.Module = ( + torch.nn.Sequential( + torch.nn.Linear(2048, 4096), + torch.nn.Linear(4096, 128), + ) + .bfloat16() + .cuda() + ) + convert_to_float8_training(model) + x = 
torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) + + jitted = thunder.jit(model) + _ = jitted(x) From c6fce3aa86d3799d77d2c45646c470b8edbf2d39 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 2 Nov 2024 23:06:21 +0900 Subject: [PATCH 2/7] the test case works with no `__torch_dispatch__` support at all. Signed-off-by: Masaki Kozuki --- thunder/core/prims.py | 29 ++++++ thunder/core/proxies.py | 102 +++++++++++++++++++++++--- thunder/executors/torchex.py | 12 +++ thunder/tests/nvfuser_repro.py | 28 ------- thunder/tests/test_tensor_subclass.py | 55 +------------- thunder/tests/test_torchao_float8.py | 28 ------- 6 files changed, 135 insertions(+), 119 deletions(-) delete mode 100644 thunder/tests/nvfuser_repro.py delete mode 100644 thunder/tests/test_torchao_float8.py diff --git a/thunder/core/prims.py b/thunder/core/prims.py index c17a28296c..c9468d653d 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import auto, Enum from numbers import Number from functools import reduce, wraps @@ -77,6 +79,7 @@ def register_method(method_name: str, method: Callable, /) -> None: TupleProxy, AnyProxy, IntegerProxy, + SubclassTensorProxy, ) import thunder.core.codeutils as codeutils from thunder.core.codeutils import Printable @@ -272,6 +275,8 @@ class PrimIDs(Enum): COPY_ = auto() # SINK = auto() + # Tensor subclass methods + TENSOR_SUBCLASS_CTOR = auto() class OpTags(Enum): @@ -4048,3 +4053,27 @@ def sink_meta(*args, **kwargs): # TODO do we want another tag to remove this after prologue is constructed? sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,)) + + +def tensor_subclass_ctor_meta( + cls, name, shape, device, dtype, requires_grad, tensors, non_tensors +) -> SubclassTensorProxy: + s = SubclassTensorProxy( + name, + subclass_type=cls, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + tensors=tensors, + non_tensors=non_tensors, + history=[t.history for t in tensors], + ) + return s + + +tensor_subclass_ctor = make_prim( + PrimIDs.TENSOR_SUBCLASS_CTOR, + "tensor_subclass_ctor", + meta=tensor_subclass_ctor_meta, +) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 5118caa3c3..943ec1a2d7 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1886,23 +1886,103 @@ class SubclassTensorProxy(TensorProxy): _subclass_type: torch._C._TensorMeta def __init__(self, *args, **kwargs): - tensors = kwargs.pop("tensors", []) - non_tensors = kwargs.pop("non_tensors", []) + from thunder.core.pytree import tree_flatten + + kwarg_tensors = kwargs.pop("tensors", []) + kwarg_non_tensors = kwargs.pop("non_tensors", []) subclass_type = kwargs.pop("subclass_type", None) - if not hasattr(self, "_name"): + + # If `tensors` (and `non_tensors`) are not empty, this is the path following `_make_wrapper_subclass`, + # where `self` should already have gotten its name. 
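+        # Tensor proxies among the flattened arguments become the subclass's inner tensors; + # everything else is kept as constant metadata.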
+ flat_args, spec = tree_flatten((args, kwargs)) + tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args)) + non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args)) + has_name_before_init = hasattr(self, "_name") + + is_dunder_init_following_make_wrapper_subclass: bool = False + if tensors: + baseutils.check( + has_name_before_init + and not kwarg_tensors + and not kwarg_non_tensors + and self._subclass_type is not None, + lambda: f"{flat_args=} indicates this instance is created by `torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set", + ) + is_dunder_init_following_make_wrapper_subclass = True + + if not is_dunder_init_following_make_wrapper_subclass: super().__init__(*args, **kwargs) - self._tensors = tensors - self._non_tensors = non_tensors + + if not is_dunder_init_following_make_wrapper_subclass: + self._tensors = kwarg_tensors + self._non_tensors = kwarg_non_tensors self._subclass_type = subclass_type else: - from thunder.core.pytree import tree_flatten + self._tensors = tensors + self._non_tensors = non_tensors - flat_args, spec = tree_flatten((args, kwargs)) - self._tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args)) - self._non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args)) - baseutils.check( - self.history is not None, lambda: f"SubclassTensorProxy {self._name} must have its `history` set" + if is_dunder_init_following_make_wrapper_subclass: + from thunder.core import prims + + bsym = prims.tensor_subclass_ctor.bind( + self._subclass_type, + self.name, + self.shape, + self.device, + self.dtype, + self.requires_grad, + self._tensors, + self._non_tensors, + output=self, ) + get_tracectx().add_bound_symbol(bsym) + + def replace(self, **changes): + r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments. + Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``. + ``like`` is also a valid keyword and will take metadata from the tensor proxy argument + in preference to the old values but overridable by keyword arguments. 
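+        The subclass-specific fields ``tensors``, ``non_tensors``, and ``subclass_type`` are carried over from ``self``.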
+ Note that the copy will use the current (environment) tracectx.""" + + like = changes.get("like") + ( + shape, + device, + dtype, + true_dtype, + numel, + ndim, + requires_grad, + grad, + distparallel_type, + thunder_fsdp_padding_size, + ) = _infer_tensor_properties( + like, + changes.get("shape", self._shape if like is None else None), + changes.get("device", self._device if like is None else None), + changes.get("dtype", self._dtype if like is None else None), + changes.get("requires_grad", self._requires_grad if like is None else None), + changes.get("grad", self._grad if like is None else None), + changes.get("distparallel_type", self._distparallel_type if like is None else None), + changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None), + ) + name = changes.get("name", self.name) + history = changes.get("history", self.history) + tags = changes.get("tags", self.tags) + return SubclassTensorProxy( + name=name, + tags=tags, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + distparallel_type=distparallel_type, + thunder_fsdp_padding_size=thunder_fsdp_padding_size, + history=history, + tensors=self._tensors, + non_tensors=self._non_tensors, + subclass_type=self._subclass_type, + ) class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface): diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index afff715728..dc15653f2d 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2174,3 +2174,15 @@ def _shape_impl(t): shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x) _register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable) + + +def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors): + return cls(*tensors, *non_tensors) + + +tensor_subclass_ctor = ex.register_operator( + "tensor_subclass_ctor", + meta=prims.tensor_subclass_ctor, + fn=_tensor_subclass_ctor, +) +_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable) diff --git a/thunder/tests/nvfuser_repro.py b/thunder/tests/nvfuser_repro.py deleted file mode 100644 index f48708dc5e..0000000000 --- a/thunder/tests/nvfuser_repro.py +++ /dev/null @@ -1,28 +0,0 @@ -# CUDA devices: -# 0: NVIDIA RTX 6000 Ada Generation -# torch version: 2.6.0a0+git408fe41 -# cuda version: 12.6 -# nvfuser version: 0.2.11+gitaad7286 -import torch -from nvfuser import FusionDefinition, DataType - - -def nvfuser_fusion_id0(fd: FusionDefinition) -> None: - T0 = fd.define_tensor( - shape=[2, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0] - ) - T1 = fd.define_tensor( - shape=[2, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0] - ) - T2 = fd.ops.add(T0, T1) - fd.add_output(T2) - - -with FusionDefinition() as fd: - nvfuser_fusion_id0(fd) - -inputs = [ - torch.randn(4, dtype=torch.float32, device="cuda:0").as_strided((2, 2), (2, 1)), - torch.randn(4, dtype=torch.float32, device="cuda:0").as_strided((2, 2), (2, 1)), -] -fd.execute(inputs) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index 0481885044..6459fdec21 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -87,58 +87,7 @@ def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any): return ScaleTensorSubclass(out, scales[0]) -# Error message: -# unpack_fn = 
d.get(inst) -# if unpack_fn is None: -# > raise NotImplementedError(f"Unpacking from {inst} {provenance}") -# E NotImplementedError: Unpacking from LOOKASIDE ProvenanceRecord( -# E i1 = INPUT_FN() -# E i2 = LOAD_ATTR(i1, '__globals__') -# E i3 = BINARY_SUBSCR(i2, 'ScaleTensorSubclass') -# E i4 = LOOKASIDE(i3) -# E ) -# -# thunder/core/jit_ext.py:1503: NotImplementedError -# -# The above exception was the direct cause of the following exception: -# -# def test_subclass_ctor(): -# -# def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: -# return ScaleTensorSubclass(x, scale) -# -# device = torch.device("cuda") -# dtype = torch.float32 -# shape = (2, 2) -# x = make_tensor(shape, device=device, dtype=dtype) -# scale = make_tensor((), device=device, dtype=dtype) -# -# jitted = thunder.jit(f) -# -# expected = f(x, scale) -# > actual = jitted(x, scale) -# -# thunder/tests/test_tensor_subclass.py:104: -# _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -# thunder/__init__.py:768: in wrapped -# return fn(*args, **kwargs) -# thunder/__init__.py:818: in fn_ -# cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) -# thunder/__init__.py:750: in wrapped -# cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) -# thunder/core/langctxs.py:136: in _fn -# result = fn(*args, **kwargs) -# thunder/__init__.py:234: in cache_info_wrapper -# res = fn(*args, **kwargs) -# thunder/__init__.py:522: in get_computation_and_inputs -# jit_results: TraceResults = thunder_general_jit( -# thunder/core/jit_ext.py:1788: in thunder_general_jit -# pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs) -# thunder/core/jit_ext.py:1576: in unpack_inputs -# pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1])) -# thunder/core/jit_ext.py:1576: in -# pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1])) -def test_subclass_ctor(): +def test_func_of_subclass_ctor_wrapper(): def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: return ScaleTensorSubclass(x, scale) @@ -153,3 +102,5 @@ def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: expected = f(x, scale) actual = jitted(x, scale) + assert type(expected) is type(actual) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) diff --git a/thunder/tests/test_torchao_float8.py b/thunder/tests/test_torchao_float8.py deleted file mode 100644 index 59e4d2909c..0000000000 --- a/thunder/tests/test_torchao_float8.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -pytest.importorskip("torchao") - -import torch -from torchao.float8 import convert_to_float8_training - -import thunder - - -@pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9), - reason="Requires cuda of 8.9 or higher", -) -def test_float8_linear(): - model: torch.nn.Module = ( - torch.nn.Sequential( - torch.nn.Linear(2048, 4096), - torch.nn.Linear(4096, 128), - ) - .bfloat16() - .cuda() - ) - convert_to_float8_training(model) - x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) - - jitted = thunder.jit(model) - _ = jitted(x) From af323615d6ba19cc8b0c1a2b9de2d4d47c50b2ec Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 3 Nov 2024 
15:35:01 +0900 Subject: [PATCH 3/7] attribute access on subclass proxy seems to be functioning somehow Signed-off-by: Masaki Kozuki --- thunder/tests/test_tensor_subclass.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index 6459fdec21..1bf45a0ef9 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -4,6 +4,7 @@ import torch import thunder +from thunder.tests.framework import instantiate, NOTHING from thunder.tests.make_tensor import make_tensor if TYPE_CHECKING: from typing import Any @@ -87,20 +88,33 @@ def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any): return ScaleTensorSubclass(out, scales[0]) -def test_func_of_subclass_ctor_wrapper(): +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_of_subclass_ctor_wrapper(executor, device, _): def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: - return ScaleTensorSubclass(x, scale) + y = ScaleTensorSubclass(x, scale) + return y + + jitted = executor.make_callable(f) - device = torch.device("cuda") dtype = torch.float32 shape = (2, 2) x = make_tensor(shape, device=device, dtype=dtype) scale = make_tensor((), device=device, dtype=dtype) - jitted = thunder.jit(f) + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + def f(x: torch.Tensor, scale: torch.Tensor): + y = ScaleTensorSubclass(x, scale) + z = ScaleTensorSubclass(y._x, y._scale) + return z + + jitted = executor.make_callable(f) expected = f(x, scale) actual = jitted(x, scale) torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) From 546fb106ae650d2e86088a9f39557774724f169d Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 3 Nov 2024 15:44:11 +0900 Subject: [PATCH 4/7] simplify if-else in `SubclassTensorProxy.__init__` Signed-off-by: Masaki Kozuki --- thunder/core/proxies.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 943ec1a2d7..80d60a4894 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1906,24 +1906,26 @@ def __init__(self, *args, **kwargs): and not kwarg_tensors and not kwarg_non_tensors and self._subclass_type is not None, - lambda: f"{flat_args=} indicates this instance is created by `torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set", + lambda: ( + f"{flat_args=} indicates this instance is created by " + "`torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set" + ), ) is_dunder_init_following_make_wrapper_subclass = True if not is_dunder_init_following_make_wrapper_subclass: super().__init__(*args, **kwargs) - if not is_dunder_init_following_make_wrapper_subclass: self._tensors = kwarg_tensors self._non_tensors = kwarg_non_tensors self._subclass_type = subclass_type else: - self._tensors = tensors - self._non_tensors = non_tensors - - if is_dunder_init_following_make_wrapper_subclass: + # TODO(crcrpar): Think about materializing `self` so that we can + # call `__tensor_init__` to learn each attribute's name. 
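+            # This is the path following the `_make_wrapper_subclass` lookaside: `self` already + # has its name, so record a `tensor_subclass_ctor` bound symbol that reconstructs this + # proxy when the trace is executed.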
from thunder.core import prims + self._tensors = tensors + self._non_tensors = non_tensors bsym = prims.tensor_subclass_ctor.bind( self._subclass_type, self.name, From 958d235b4aa4f04edae86fd9b5bdb3cbdf9a504f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 3 Nov 2024 16:08:16 +0900 Subject: [PATCH 5/7] stricter type check of tensors Signed-off-by: Masaki Kozuki --- thunder/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index c09c1cc9b5..104af7734f 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -369,7 +369,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]: data_ptr_to_tensor_group_index = {} tensor_group_index_to_tensor_indices = defaultdict(list) for idx, t in enumerate(flat_args): - if pytorch.is_tensor(t) and t.layout == pytorch.strided: + if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided: data_ptr = t.untyped_storage().data_ptr() if data_ptr not in data_ptr_to_tensor_group_index: data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index) From 9900d33dd559b9a279a010035d496495a18d7366 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 5 Nov 2024 16:37:24 +0900 Subject: [PATCH 6/7] support `MySubclass(...)` called inside of `torch.autograd.Function` Signed-off-by: Masaki Kozuki --- thunder/core/proxies.py | 3 +- thunder/tests/test_tensor_subclass.py | 64 ++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 80d60a4894..eb605dc2e8 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1937,7 +1937,8 @@ def __init__(self, *args, **kwargs): self._non_tensors, output=self, ) - get_tracectx().add_bound_symbol(bsym) + current_trace = get_tracectx() + current_trace.scopes[-1].append(bsym) def replace(self, **changes): r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments. 
diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index 1bf45a0ef9..3652965a83 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -4,13 +4,43 @@ import torch import thunder -from thunder.tests.framework import instantiate, NOTHING +from thunder.tests.framework import instantiate from thunder.tests.make_tensor import make_tensor if TYPE_CHECKING: from typing import Any +@torch._dynamo.allow_in_graph +class EncapsulateXandScale(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, scale: torch.Tensor): + return ScaleTensorSubclass(x, scale) + + @staticmethod + def backward(ctx, grad): + return grad, None + + +def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass: + return EncapsulateXandScale.apply(x, scale) + + +@torch._dynamo.allow_in_graph +class ToScaleTensorSubclass(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor): + return ScaleTensorSubclass.from_tensor(x) + + @staticmethod + def backward(ctx, grad): + return grad + + +def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass: + return ToScaleTensorSubclass.apply(x) + + class ScaleTensorSubclass(torch.Tensor): _x: torch.Tensor _scale: torch.Tensor __slots__ = ["_x", "_scale"] @@ -118,3 +148,35 @@ def f(x: torch.Tensor, scale: torch.Tensor): expected = f(x, scale) actual = jitted(x, scale) torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_calling_converter(executor, device, _): + + def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: + y = encapsulate_x_and_scale(x, scale) + return y + + jitted = executor.make_callable(f) + + dtype = torch.float32 + shape = (2, 2) + x = make_tensor(shape, device=device, dtype=dtype) + scale = make_tensor((), device=device, dtype=dtype) + + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + def g(x: torch.Tensor) -> ScaleTensorSubclass: + y = to_scale_tensor_subclass(x) + return y + + jitted = thunder.jit(g) + x = make_tensor(shape, device=device, dtype=dtype) + + expected = g(x) + actual = jitted(x) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) From 2e4ecad42e0e68efad7dfcf00e32c6981540fb7d Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 5 Nov 2024 16:49:05 +0900 Subject: [PATCH 7/7] explanation Signed-off-by: Masaki Kozuki --- thunder/core/proxies.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index eb605dc2e8..8cf179e263 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1937,8 +1937,19 @@ def __init__(self, *args, **kwargs): self._non_tensors, output=self, ) + # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)` + # inside of it either directly or indirectly; the indirect way is to call it through + # a custom `torch.autograd.Function` as in + # https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209. + # If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical, + # but otherwise they are not, as [the lookaside of `torch.autograd.Function`]( + # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655) + # pushes a temporary scope onto the current trace. 
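+            # So: if the innermost scope is the trace's own `bound_symbols`, register the + # bound symbol via `add_bound_symbol`; otherwise append it to that temporary scope + # directly.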
current_trace = get_tracectx() - current_trace.scopes[-1].append(bsym) + if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]): + current_trace.add_bound_symbol(bsym) + else: + cur_tail_scope.append(bsym) def replace(self, **changes): r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.