[Tensor Subclasses] Support func calling only Subclass(...)
#1393
base: main
Changes from all commits: e76a98c, c6fce3a, af32361, 546fb10, 958d235, 9900d33, 2e4ecad
@@ -62,6 +62,7 @@
    NumberProxy,
    StringProxy,
    TensorProxy,
    SubclassTensorProxy,
    FutureTensorProxy,
    make_proxy_name,
    Variable,
@@ -757,6 +758,42 @@ def grad_transform(*args, **kwargs):
    return forward_result


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)

Review comment: Does it need to be a jit lookaside? Can the implementation be moved to … ?
Reply: I'd rather keep this method as hidden as possible, so making it a lookaside feels more right than a torchsymbol.

def _make_wrapper_subclass(
    cls: torch._C._TensorMeta,
    size: Sequence[int],
    strides: Sequence[int] | None = None,
    storage_offset: int | None = None,
    memory_format: torch.memory_format | None = None,
    dtype: torch.dtype | None = None,
    layout: torch.layout | None = torch.strided,
    device: torch.device | None = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    dispatch_sizes_strides_policy: str | None = None,
    dispatch_device: bool = False,
    dispatch_layout: bool = False,
    _extra_dispatch_keys: torch.DispatchKeySet | None = None,
    storage_size: int | None = None,
):
    ucls = unwrap(cls)
    usize = unwrap(size)
    udtype = unwrap(dtype)
    udevice = unwrap(device)
    urequires_grad = unwrap(requires_grad)

Review comment (on lines +779 to +783): This should be explicit about which arguments are ignored.

    subclass = SubclassTensorProxy(
        None,
        shape=usize,
        device=udevice,
        dtype=udtype,
        requires_grad=urequires_grad,
        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
        subclass_type=ucls,
    )
    return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
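For context, the lookaside above targets the usual user-side pattern in which a wrapper subclass allocates itself through `torch.Tensor._make_wrapper_subclass` in its `__new__`. The sketch below is illustrative only and is not part of this diff; `ScaledTensor` is a hypothetical subclass name.

```python
import torch


class ScaledTensor(torch.Tensor):
    """Hypothetical wrapper subclass, used only to illustrate the intercepted call."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: float):
        # Allocates no real storage; only records metadata for the wrapper.
        # Under thunder.jit, this is the call the lookaside above intercepts,
        # returning a SubclassTensorProxy instead of a real wrapper tensor.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, scale: float):
        self._data = data
        self._scale = scale
```

A complete subclass would also define `__torch_dispatch__` and the flatten hooks; only the constructor path matters for this lookaside.
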

@register_general_jit_lookaside(torch.autocast.__enter__)
def autocast_enter(autocast_obj):
    unwrap_autocast_obj = unwrap(autocast_obj)
@@ -1,3 +1,5 @@
from __future__ import annotations

from enum import auto, Enum
from numbers import Number
from functools import reduce, wraps
@@ -77,6 +79,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
    TupleProxy,
    AnyProxy,
    IntegerProxy,
    SubclassTensorProxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -272,6 +275,8 @@ class PrimIDs(Enum):
    COPY_ = auto()
    #
    SINK = auto()
    # Tensor Subclasses methods
    TENSOR_SUBCLASS_CTOR = auto()


class OpTags(Enum):
@@ -4048,3 +4053,27 @@ def sink_meta(*args, **kwargs):


# TODO do we want another tag to remove this after prologue is constructed?
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))


def tensor_subclass_ctor_meta(
    cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
) -> SubclassTensorProxy:
    s = SubclassTensorProxy(
        name,
        subclass_type=cls,
        shape=shape,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        tensors=tensors,
        non_tensors=non_tensors,
        history=[t.history for t in tensors],

Review comment: What's the purpose of this `history`?
Reply: I don't know exactly; something disallows `history` being empty/None.

    )
    return s


tensor_subclass_ctor = make_prim(
    PrimIDs.TENSOR_SUBCLASS_CTOR,
    "tensor_subclass_ctor",
    meta=tensor_subclass_ctor_meta,
)
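The `tensors`/`non_tensors` split in the meta mirrors the standard wrapper-subclass flatten protocol, in which inner tensors and plain metadata are reported separately. Below is a hedged sketch of that protocol; `__tensor_flatten__`/`__tensor_unflatten__` are the real PyTorch hooks, while `ScaledTensor` remains a hypothetical example class.

```python
import torch


class ScaledTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: float):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data: torch.Tensor, scale: float):
        self._data, self._scale = data, scale

    def __tensor_flatten__(self):
        # Names of inner tensor attributes, plus a context tuple of non-tensor metadata.
        return ["_data"], (self._scale,)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # Rebuild the wrapper from its inner tensors and metadata.
        (scale,) = ctx
        return ScaledTensor(inner_tensors["_data"], scale)
```
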
@@ -1880,6 +1880,125 @@ def real(self):
        return method(self)


class SubclassTensorProxy(TensorProxy):

Review comment: This should be able to express …
Review comment: Would you add a comment describing what this class is intended for?

    _tensors: list[TensorProxy]
    _non_tensors: list[Any]
    _subclass_type: torch._C._TensorMeta

    def __init__(self, *args, **kwargs):
        from thunder.core.pytree import tree_flatten

        kwarg_tensors = kwargs.pop("tensors", [])
        kwarg_non_tensors = kwargs.pop("non_tensors", [])
        subclass_type = kwargs.pop("subclass_type", None)

        # If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass`

Review comment: Would you elaborate on this comment?

        # where `self` should already have gotten its name.
        flat_args, spec = tree_flatten((args, kwargs))
        tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args))
        non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
        has_name_before_init = hasattr(self, "_name")

Review comment: How can this happen?
Reply: If thunder sees a tensor wrapper subclass of …
Review comment: I guess I'm still confused about how and why the proxy class is entangled with the actual subclass.


        is_dunder_init_following_make_wrapper_subclass: bool = False

Review comment: What's going on here?
Reply: Ditto.

        if tensors:
            baseutils.check(
                has_name_before_init
                and not kwarg_tensors
                and not kwarg_non_tensors
                and self._subclass_type is not None,
                lambda: (
                    f"{flat_args=} indicates this instance is created by "
                    "`torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set"
                ),
            )
            is_dunder_init_following_make_wrapper_subclass = True

        if not is_dunder_init_following_make_wrapper_subclass:
            super().__init__(*args, **kwargs)

            self._tensors = kwarg_tensors
            self._non_tensors = kwarg_non_tensors
            self._subclass_type = subclass_type
        else:
            # TODO(crcrpar): Think about materializing `self` so that we can
            # call `__tensor_init__` to know each attribute's name.
            from thunder.core import prims

            self._tensors = tensors
            self._non_tensors = non_tensors
            bsym = prims.tensor_subclass_ctor.bind(
                self._subclass_type,
                self.name,
                self.shape,
                self.device,
                self.dtype,
                self.requires_grad,
                self._tensors,
                self._non_tensors,
                output=self,
            )

Review comment: This is interesting. Why does this class either create a new proxy or change the actual trace?
Reply: This path is taken while tracing, and the call of …
Review comment: OK; but why does proxy creation get recorded into the trace?
Reply: If a function creates a subclass inside it, then shouldn't a trace of it have a BoundSymbol to represent it?
Review comment: I think it's OK to put tensor subclass creation into a trace (although I'm curious whether we can DCE the creation if the subclass is just flattened and used once later), but I'm not sure why the operator that creates the actual tensor subclass is also a constructor for the proxy. The existing tensor proxies don't entangle their tensor factory methods with the creation of the proxy, for example …
Review comment: I don't get where this call of … How and when would we register … ?
Review comment: Yes! What if we detected that a subclass constructor was called and dynamically created a symbol for the constructor and the corresponding proxy? The call to …
Reply: If I do that dynamically, I want that subclass proxy to correctly capture the result type of ops with that subclass. However, I don't see a clear path to infer the result type of ops with subclasses other than doing what AOT Autograd does with torch.fx tracing, so I don't think dynamic registration is that different from what I have in this PR. Representing subclass construction as an unflatten call would then call …
Review comment: The goal is to divorce the proxy from the actual runtime object to simplify the code and its concepts. How else would we infer the decomposition of operations called on a tensor subclass, and whether the result is also a member of the subclass? I understand that init gets called on the object at runtime, of course. The goal is not to somehow circumvent the proper construction of the object at runtime (unless it can be elided for performance).
Review comment: Maybe I'm not understanding this well, but I'm a bit worried that we don't have well-defined semantics here and that the traces are not representing what's going on. We should absolutely know when we want to construct and use the subclass object (if we return it or something wants the subclass object) and when we don't (all other cases), and we should also represent what the compute will be in the trace. We also want to minimize admin overhead during the compute, which we are not doing a great job of today (having looked at wall clock vs. GPU self time for Llama 1b today), so the overhead of dealing with subclasses at compute time should likely be minimized.

            # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
            # inside of it either directly or indirectly: the indirect way is to call it through
            # a custom `torch.autograd.Function` as in
            # https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209.
            # If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical,
            # but not otherwise, as [the lookaside of `torch.autograd.Function`](
            # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
            # puts a temporary scope onto the current trace.
            current_trace = get_tracectx()
            if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
                current_trace.add_bound_symbol(bsym)
            else:
                cur_tail_scope.append(bsym)
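The NOTE above distinguishes two ways a jitted callable can construct the subclass. Here is a hedged sketch of both paths; `_ToScaled`, `construct_directly`, and `construct_indirectly` are illustrative names, and `ScaledTensor` is the hypothetical wrapper subclass from the earlier sketches.

```python
import torch

# ScaledTensor: the hypothetical wrapper subclass defined in the earlier sketch.


class _ToScaled(torch.autograd.Function):
    @staticmethod
    def forward(ctx, data: torch.Tensor, scale: float):
        # Indirect construction: this happens inside the autograd.Function lookaside,
        # where a temporary scope (trace.scopes[-1]) receives the bound symbol.
        return ScaledTensor(data, scale)

    @staticmethod
    def backward(ctx, grad):
        return grad, None


def construct_directly(data: torch.Tensor):
    # Direct construction: trace.bound_symbols is the active scope.
    return ScaledTensor(data, 0.5)


def construct_indirectly(data: torch.Tensor):
    return _ToScaled.apply(data, 0.5)
```
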

    def replace(self, **changes):

Review comment: Why is a replace function needed?
Reply: Because I don't want …

        r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.
        Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
        ``like`` is also a valid keyword and will take metadata from the tensor proxy argument
        in preference to the old values, but is overridable by keyword arguments.
        Note that the copy will use the current (environment) tracectx."""

        like = changes.get("like")

Review comment: If … ?
Reply: If the key is not found, then the default value is returned; in this case, …

        (
            shape,
            device,
            dtype,
            true_dtype,
            numel,
            ndim,
            requires_grad,
            grad,
            distparallel_type,
            thunder_fsdp_padding_size,
        ) = _infer_tensor_properties(
            like,
            changes.get("shape", self._shape if like is None else None),

Review comment: Why not … ?
Reply: I copied these lines from …
Review comment: But doesn't the current code mean that if … ?

            changes.get("device", self._device if like is None else None),
            changes.get("dtype", self._dtype if like is None else None),
            changes.get("requires_grad", self._requires_grad if like is None else None),
            changes.get("grad", self._grad if like is None else None),
            changes.get("distparallel_type", self._distparallel_type if like is None else None),
            changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None),
        )
        name = changes.get("name", self.name)
        history = changes.get("history", self.history)
        tags = changes.get("tags", self.tags)
        return SubclassTensorProxy(
            name=name,
            tags=tags,
            shape=shape,
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            distparallel_type=distparallel_type,
            thunder_fsdp_padding_size=thunder_fsdp_padding_size,
            history=history,
            tensors=self._tensors,
            non_tensors=self._non_tensors,
            subclass_type=self._subclass_type,
        )
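On the `changes.get(...)` questions above: `dict.get` never raises for a missing key, so the second argument is only a fallback, and here that fallback is forced to `None` whenever `like` is given. A minimal illustration of the lookup behavior:

```python
changes = {"dtype": "float16"}

changes.get("shape", (2, 2))   # -> (2, 2): key missing, the default is used
changes.get("dtype", None)     # -> "float16": key present, the default is ignored
changes.get("like")            # -> None: the single-argument form defaults to None
```
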


class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface):
    def __init__(
        self,
Review comment: Could you please submit this fix with a test in a separate pull request?
Reply: That pull request would need a subclass in the test, so I'm not quite convinced by that option.
Review comment: How do subclasses appear here? Do all subclasses have the actual torch.Tensor type?
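On the last question, a quick factual note (illustrative snippet, not part of this diff): instances of a tensor subclass are still `torch.Tensor` instances, but their concrete type is the subclass itself, which is what identity checks on `type(...)` observe.

```python
import torch

p = torch.nn.Parameter(torch.randn(2, 2))  # nn.Parameter is a torch.Tensor subclass

isinstance(p, torch.Tensor)   # True: subclass instances pass isinstance checks
type(p) is torch.Tensor       # False: the concrete type is the subclass
```
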