[Tensor Subclasses] Support func calling only Subclass(...) #1393

Open
wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion thunder/__init__.py
@@ -369,7 +369,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
    data_ptr_to_tensor_group_index = {}
    tensor_group_index_to_tensor_indices = defaultdict(list)
    for idx, t in enumerate(flat_args):
-        if pytorch.is_tensor(t) and t.layout == pytorch.strided:
+        if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
Collaborator:
Could you please submit this fix with a test in a separate pull request?

Collaborator Author (@crcrpar, Nov 6, 2024):
That pull request would need a subclass in the test, so I'm not quite convinced by that option.

Collaborator:
How do subclasses appear here? Do all subclasses have the actual torch.Tensor type?
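
A note on the type check above: torch.is_tensor matches tensor subclasses as well, while the exact type(t) in {...} check does not. A minimal sketch of the difference (MySubclass is a hypothetical example, not part of this PR):

import torch

class MySubclass(torch.Tensor):
    pass

t = torch.randn(2)
s = torch.randn(2).as_subclass(MySubclass)

print(torch.is_tensor(t), torch.is_tensor(s))            # True True
print(type(t) is torch.Tensor, type(s) is torch.Tensor)  # True False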

            data_ptr = t.untyped_storage().data_ptr()
            if data_ptr not in data_ptr_to_tensor_group_index:
                data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
37 changes: 37 additions & 0 deletions thunder/core/jit_ext.py
@@ -62,6 +62,7 @@
    NumberProxy,
    StringProxy,
    TensorProxy,
    SubclassTensorProxy,
    FutureTensorProxy,
    make_proxy_name,
    Variable,
@@ -757,6 +758,42 @@ def grad_transform(*args, **kwargs):
return forward_result


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
Collaborator:

Does it need to be a jit lookaside? Can the implementation be moved to thunder/torch/__init__.py if @torchsymbol is used?

Collaborator Author:
I'd rather hide this method as much as possible, so making it a lookaside feels more right than a torchsymbol.

def _make_wrapper_subclass(
    cls: torch._C._TensorMeta,
    size: Sequence[int],
    strides: Sequence[int] | None = None,
    storage_offset: int | None = None,
    memory_format: torch.memory_format | None = None,
    dtype: torch.dtype | None = None,
    layout: torch.layout | None = torch.strided,
    device: torch.device | None = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    dispatch_sizes_strides_policy: str | None = None,
    dispatch_device: bool = False,
    dispatch_layout: bool = False,
    _extra_dispatch_keys: torch.DispatchKeySet | None = None,
    storage_size: int | None = None,
):
    ucls = unwrap(cls)
    usize = unwrap(size)
    udtype = unwrap(dtype)
    udevice = unwrap(device)
    urequires_grad = unwrap(requires_grad)
Collaborator Author (on lines +779 to +783):
We should be clear here about which arguments are ignored.


    subclass = SubclassTensorProxy(
        None,
        shape=usize,
        device=udevice,
        dtype=udtype,
        requires_grad=urequires_grad,
        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
        subclass_type=ucls,
    )
    return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
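
For context, a minimal sketch of the kind of code this lookaside intercepts: a hypothetical wrapper subclass (not part of this PR) whose __new__ goes through torch.Tensor._make_wrapper_subclass and whose __init__ stores the inner tensor.

import torch

class MyWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        # This call is what the lookaside above replaces with a SubclassTensorProxy while tracing.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor):
        self._inner = inner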


@register_general_jit_lookaside(torch.autocast.__enter__)
def autocast_enter(autocast_obj):
unwrap_autocast_obj = unwrap(autocast_obj)
29 changes: 29 additions & 0 deletions thunder/core/prims.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from enum import auto, Enum
from numbers import Number
from functools import reduce, wraps
@@ -77,6 +79,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
    TupleProxy,
    AnyProxy,
    IntegerProxy,
    SubclassTensorProxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -272,6 +275,8 @@ class PrimIDs(Enum):
    COPY_ = auto()
    #
    SINK = auto()
    # Tensor Subclasses methods
    TENSOR_SUBCLASS_CTOR = auto()


class OpTags(Enum):
@@ -4048,3 +4053,27 @@ def sink_meta(*args, **kwargs):

# TODO do we want another tag to remove this after prologue is constructed?
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))


def tensor_subclass_ctor_meta(
    cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
) -> SubclassTensorProxy:
    s = SubclassTensorProxy(
        name,
        subclass_type=cls,
        shape=shape,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        tensors=tensors,
        non_tensors=non_tensors,
        history=[t.history for t in tensors],
Collaborator Author:
The history constructed here could be better.

Collaborator:
What's the purpose of history?

Collaborator Author:
I'm not sure; something disallows history being empty/None.

    )
    return s


tensor_subclass_ctor = make_prim(
    PrimIDs.TENSOR_SUBCLASS_CTOR,
    "tensor_subclass_ctor",
    meta=tensor_subclass_ctor_meta,
)
119 changes: 119 additions & 0 deletions thunder/core/proxies.py
@@ -1880,6 +1880,125 @@ def real(self):
        return method(self)


class SubclassTensorProxy(TensorProxy):
Collaborator Author:
This should be able to express __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]] (and __tensor_unflatten__(inner_tensors: dict[str, Tensor], metadata: dict[str, Any], outer_size, outer_stride) -> MySubclass); see the sketch below. For that to happen, instances of this class need to be given, somewhere, the attribute names of their tensors and non-tensor values.
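
A minimal sketch of the flatten/unflatten protocol being referenced, extending the hypothetical MyWrapperTensor sketched earlier in the jit_ext.py section (illustrative only, not part of this PR):

class MyWrapperTensor(torch.Tensor):
    # ... __new__ and __init__ as in the earlier sketch, storing `self._inner` ...

    def __tensor_flatten__(self):
        # attribute names of the inner tensors, plus arbitrary non-tensor metadata
        return ["_inner"], {"note": "example metadata"}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return MyWrapperTensor(inner_tensors["_inner"])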

Collaborator:
Would you add a comment describing what this class is intended for?

    _tensors: list[TensorProxy]
    _non_tensors: list[Any]
    _subclass_type: torch._C._TensorMeta

    def __init__(self, *args, **kwargs):
        from thunder.core.pytree import tree_flatten

        kwarg_tensors = kwargs.pop("tensors", [])
Collaborator (@mruberry, Nov 12, 2024):
__init__(self, *args, tensors=[], non_tensors=[], subclass_type=None, **kwargs)

?

        kwarg_non_tensors = kwargs.pop("non_tensors", [])
        subclass_type = kwargs.pop("subclass_type", None)

        # If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass`
Collaborator:
Would you elaborate on this comment?

        # where `self` should already have gotten its name.
        flat_args, spec = tree_flatten((args, kwargs))
        tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args))
        non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
        has_name_before_init = hasattr(self, "_name")
Collaborator:
How can this happen?

Collaborator Author:
If thunder is tracing a function that calls MySubclass(...), where MySubclass is a tensor wrapper subclass whose __new__ calls _make_wrapper_subclass, then the lookaside at https://github.com/Lightning-AI/lightning-thunder/pull/1393/files#diff-3d1ea50ad3b0e3ad6fc369f91a7e42011d1d33d770ce25f800637c99de85f4b5R762 creates an instance of SubclassTensorProxy, and __init__ is then called on that instance, not on a MySubclass instance.

Collaborator:
I guess I'm still confused about how and why the proxy class is entangled with the actual subclass


        is_dunder_init_following_make_wrapper_subclass: bool = False
Collaborator:
What's going on here?

Collaborator Author:
ditto

        if tensors:
            baseutils.check(
                has_name_before_init
                and not kwarg_tensors
                and not kwarg_non_tensors
                and self._subclass_type is not None,
                lambda: (
                    f"{flat_args=} indicates this instance is created by"
                    "`torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set"
                ),
            )
            is_dunder_init_following_make_wrapper_subclass = True

        if not is_dunder_init_following_make_wrapper_subclass:
            super().__init__(*args, **kwargs)

            self._tensors = kwarg_tensors
            self._non_tensors = kwarg_non_tensors
            self._subclass_type = subclass_type
        else:
            # TODO(crcrpar): Think about materializing `self` so that we can
            # call `__tensor_init__` to know each attribute names.
            from thunder.core import prims

            self._tensors = tensors
            self._non_tensors = non_tensors
            bsym = prims.tensor_subclass_ctor.bind(
Collaborator:
This is interesting. Why does this class either create a new proxy or change the actual trace?

Collaborator Author:
This path happens while tracing, and a call to Proxy.__init__ is not recorded in a trace by default, so it's necessary to deliberately register a BoundSymbol in the currently active scope.

Collaborator:
OK; but why does proxy creation get recorded into the trace?

Collaborator Author:
if a function creates a subclass inside it then shouldn't a trace of it have a BoundSymbol to represent it?

Collaborator:
I think it's OK to put tensor subclass creation into a trace (although I'm curious if we can dce the creation if the subclass is just flattened and used once later), but I'm not sure why the operator that creates the actual tensor subclass is also a constructor for the proxy. The existing tensor proxies don't entangle their tensor factory methods with the creation of the proxy, for example

Collaborator Author:
> Does the actual call to _make_wrapper_subclass tell us anything about the subclass

torch.Tensor._make_wrapper_subclass takes cls. The lookaside could let us tell things other than actual tensors and other values to be registered to the output of the method.

> I guess my question, then, is if we can also just call ComplexTensor like a practitioner would?

I don't get where this call of ComplexTensor would be. Could you elaborate on it?

> Then the meta for ComplexTensor could return a SubclassTensorProxy or even a ComplexTensorSubclassProxy.

How and when would we register ComplexTensorSubclassProxy? Would this indicate that Thunder is going to detect a subclass and define a proxy for it dynamically?

Collaborator:
Yes! What if we detected that a subclass constructor was called and dynamically created a symbol for the constructor and the corresponding proxy?

The call to ComplexTensor could be in the resulting Thunder program, like other bound symbols. Or maybe we want to represent ComplexTensor construction as an unflatten call, like ComplexTensor.__tensor_unflatten__? Both seem reasonable to me

Collaborator Author:
If I do that dynamically, I want that subclass proxy to correctly capture the result type of ops on that subclass. However, I don't see a clear path to infer the result type of ops on subclasses, other than doing what AOT Autograd does with torch.fx tracing. Thus I don't think dynamic registration is that different from what I have in this PR.

Representing subclass construction as an unflatten call would then call __init__ of that subclass, so I'm not following the point.

Collaborator:
The goal is to divorce the proxy from the actual runtime object to simplify the code and its concepts.

How else would we infer the decomposition of operations called on a tensor subclass, and if the result is also a member of the subclass?

I understand that init gets called on the object at runtime, of course. The goal is not to somehow circumvent the proper construction of the object at runtime (unless it can be elided for performance).

Collaborator:
Maybe I'm not understanding this well, but I'm a bit worried that we don't have well-defined semantics here and that the traces are not representing what's up.

We should absolutely know when we want to construct and use the subclass object (if we return it or something wants the subclass object) and when we don't (all other cases) and should also represent what will be the compute in the trace.

We also want to minimize admin overhead during the compute, which we are not doing a great job of today (having looked at wall clock vs. GPU self time for Llama 1b today), so adding the overhead of dealing with subclasses at compute time should likely be minimized.

                self._subclass_type,
                self.name,
                self.shape,
                self.device,
                self.dtype,
                self.requires_grad,
                self._tensors,
                self._non_tensors,
                output=self,
            )
            # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
            # inside of it either directly or indirectly: the indirect way is to call it through
            # a custom `torch.autograd.Function` as in
            # https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209.
            # If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical,
            # but otherwise they are not, because [the lookaside of `torch.autograd.Function`](
            # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
            # pushes a temporary scope onto the current trace.
            current_trace = get_tracectx()
            if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
                current_trace.add_bound_symbol(bsym)
            else:
                cur_tail_scope.append(bsym)
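
The "indirect" path mentioned in the note above, sketched for the hypothetical MyWrapperTensor from earlier (modeled loosely on the torchao float8 pattern; not code from this PR): the subclass is constructed inside a custom torch.autograd.Function rather than called directly by the jitted function.

import torch

class ToMyWrapper(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inner: torch.Tensor):
        # Constructing the wrapper subclass here means the lookaside fires inside
        # the autograd.Function's temporary scope, not the top-level trace scope.
        return MyWrapperTensor(inner)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def f(x):
    return ToMyWrapper.apply(x)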

    def replace(self, **changes):
Collaborator:
Why is a replace function needed?

Collaborator Author:
Because I don't want a .replace(...) call to turn a subclass tensor proxy into a plain tensor proxy.
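
A hedged usage sketch of that intent (p is a hypothetical SubclassTensorProxy created during tracing):

p2 = p.replace(requires_grad=False)
# Still a subclass proxy, with _tensors/_non_tensors/_subclass_type carried over.
assert isinstance(p2, SubclassTensorProxy)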

r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.
Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
``like`` is also a valid keyword and will take metadata from the tensor proxy argument
in preference to the old values but overridable by keyword arguments.
Note that the copy will use the current (environment) tracectx."""

like = changes.get("like")
Collaborator:
If changes doesn't have like as a key, won't this throw a KeyError?

Collaborator Author:
If the key is not found, the default value (None in this case) is returned. https://docs.python.org/3/library/stdtypes.html#dict.get
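
A one-line illustration of dict.get's default behavior:

changes = {"shape": (2, 3)}
assert changes.get("like") is None  # missing key returns the default (None), no KeyError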

        (
            shape,
            device,
            dtype,
            true_dtype,
            numel,
            ndim,
            requires_grad,
            grad,
            distparallel_type,
            thunder_fsdp_padding_size,
        ) = _infer_tensor_properties(
            like,
            changes.get("shape", self._shape if like is None else None),
Collaborator (@mruberry, Nov 12, 2024):
Why not

*[changes.get(key, None) for key in ('shape', 'device', ...)]

?

Collaborator Author:
I copied these lines from TensorProxy.replace and I'm lazy enough not to do that

Collaborator:
But doesn't the current code mean that if like is specified then the shape of the like tensor is overridden with the shape of the current tensor? Is that what's intended?

changes.get("device", self._device if like is None else None),
changes.get("dtype", self._dtype if like is None else None),
changes.get("requires_grad", self._requires_grad if like is None else None),
changes.get("grad", self._grad if like is None else None),
changes.get("distparallel_type", self._distparallel_type if like is None else None),
changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None),
)
name = changes.get("name", self.name)
history = changes.get("history", self.history)
tags = changes.get("tags", self.tags)
return SubclassTensorProxy(
name=name,
tags=tags,
shape=shape,
device=device,
dtype=dtype,
requires_grad=requires_grad,
distparallel_type=distparallel_type,
thunder_fsdp_padding_size=thunder_fsdp_padding_size,
history=history,
tensors=self._tensors,
non_tensors=self._non_tensors,
subclass_type=self._subclass_type,
)


class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface):
    def __init__(
        self,
12 changes: 12 additions & 0 deletions thunder/executors/torchex.py
@@ -2174,3 +2174,15 @@ def _shape_impl(t):

shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x)
_register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable)


def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
    return cls(*tensors, *non_tensors)


tensor_subclass_ctor = ex.register_operator(
    "tensor_subclass_ctor",
    meta=prims.tensor_subclass_ctor,
    fn=_tensor_subclass_ctor,
)
_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable)
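
A hedged end-to-end sketch of what the PR title describes ("Support func calling only Subclass(...)"): a thunder.jit'd function whose only work is constructing the hypothetical wrapper subclass sketched earlier. Under this PR, the construction would be traced via the tensor_subclass_ctor prim and executed by the operator registered above.

import thunder
import torch

def f(inner):
    return MyWrapperTensor(inner)  # MyWrapperTensor as sketched in the jit_ext.py section

jitted_f = thunder.jit(f)
out = jitted_f(torch.randn(2, 2))
# Expectation: out is a MyWrapperTensor built by _tensor_subclass_ctor at runtime.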