[Tensor Subclasses] Support func calling only `Subclass(...)` #1393
Conversation
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)
should be clear about ignored args
requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
history=[t.history for t in tensors],
`history` should be able to get better
What's the purpose of `history`?
dunno, something disallows history being empty/none
@@ -1880,6 +1880,111 @@ def real(self):
        return method(self)


class SubclassTensorProxy(TensorProxy):
This should be able to express `__tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]` (and `__tensor_unflatten__(inner_tensors: dict[str, Tensor], metadata: dict[str, Any], outer_size, outer_stride) -> MySubclass`). For that to happen, somewhere I have to give instances of this class the attribute names of the tensors and the non-tensor values; see the sketch below.
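For reference, a minimal sketch of that protocol on a hypothetical `ScaledTensor` wrapper subclass (the class and attribute names are made up, not taken from this PR; only the dunder signatures follow PyTorch's protocol):

```python
import torch

class ScaledTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, scale):
        # storage-less wrapper; the real data lives in attributes
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data    # tensor attribute
        self._scale = scale  # non-tensor metadata

    def __tensor_flatten__(self):
        # attribute names of the inner tensors, plus a metadata dict
        return ["_data"], {"scale": self._scale}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return ScaledTensor(inner_tensors["_data"], metadata["scale"])
```

So the proxy would need to carry exactly those two pieces of information: the list of tensor attribute names and the non-tensor metadata.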
Would you add a comment describing what this class is intended for?
force-pushed from 0b69d52 to 9d77226
assert scale.numel() == 1, f"Invalid `scale`: {scale}"
dtype = x.dtype
device = x.device
self = torch.Tensor._make_wrapper_subclass(
Do we need to use `make_wrapper_subclass` when `requires_grad=False`? Here the behavior is different depending on the `requires_grad` value:
https://github.com/albanD/subclass_zoo/blob/ec47458346c2a1cfcd5e676926a4bbc6709ff62e/base_tensor.py#L12-L15
The link uses `_make_subclass`, not `_make_wrapper_subclass`, and the last update is 2 years ago, so it doesn't sound convincing to me.
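To make the distinction being discussed concrete, here is a hedged sketch of the two constructors (hypothetical classes, not the code under review): `_make_subclass` re-wraps an existing tensor's storage, while `_make_wrapper_subclass` creates a storage-less shell that keeps its data in attributes.

```python
import torch

class ViewStyle(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, requires_grad=False):
        # shares elem's storage; the subclass instance *is* the data
        return torch.Tensor._make_subclass(cls, elem, requires_grad)

class WrapperStyle(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, requires_grad=False):
        # allocates no storage; relies on inner tensors held as attributes
        # (and, for real use, on __torch_dispatch__)
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device, requires_grad=requires_grad
        )

    def __init__(self, elem, requires_grad=False):
        self._elem = elem
```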
@@ -743,6 +744,42 @@ def grad_transform(*args, **kwargs):
    return forward_result


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
Does it need to be a jit lookaside? Can the implementation be moved to thunder/torch/__init__.py if `@torchsymbol` is used?
I'd rather hide this method as much as possible, so making it a lookaside feels more right than a torchsymbol.
@@ -369,7 +369,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
    data_ptr_to_tensor_group_index = {}
    tensor_group_index_to_tensor_indices = defaultdict(list)
    for idx, t in enumerate(flat_args):
-        if pytorch.is_tensor(t) and t.layout == pytorch.strided:
+        if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
Could you please submit this fix with a test in a separate pull request?
That pull request would need a subclass in the test, so I'm not quite convinced by that option.
How do subclasses appear here? Do all subclasses have the actual torch.Tensor type?
force-pushed from 21c2af8 to 11fea26
force-pushed from 11fea26 to 2e4ecad
kwarg_non_tensors = kwargs.pop("non_tensors", [])
subclass_type = kwargs.pop("subclass_type", None)

# If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass`
Would you elaborate on this comment?
non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
has_name_before_init = hasattr(self, "_name")

is_dunder_init_following_make_wrapper_subclass: bool = False
What's going on here?
ditto
def __init__(self, *args, **kwargs):
    from thunder.core.pytree import tree_flatten

    kwarg_tensors = kwargs.pop("tensors", [])
`__init__(self, *args, tensors=[], non_tensors=[], subclass_type=None, **kwargs)`?
flat_args, spec = tree_flatten((args, kwargs))
tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args))
non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
has_name_before_init = hasattr(self, "_name")
How can this happen?
If thunder sees a tensor wrapper subclass `MySubclass(...)` whose dunder new calls `_make_wrapper_subclass` in a function thunder is tracing, the lookaside at https://github.com/Lightning-AI/lightning-thunder/pull/1393/files#diff-3d1ea50ad3b0e3ad6fc369f91a7e42011d1d33d770ce25f800637c99de85f4b5R762 creates an instance of SubclassTensorProxy, and then the dunder init of that instance is called, not the dunder init of a `MySubclass` instance.
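Schematically, with a hypothetical `MySubclass` (this mirrors the flow described above and in the PR description, under thunder's jit interpreter rather than plain CPython semantics):

```python
# MySubclass(x, scale)
#   -> MySubclass.__new__(cls, x, scale)
#        -> torch.Tensor._make_wrapper_subclass(cls, ...)
#           # intercepted by the lookaside while tracing;
#           # returns a SubclassTensorProxy instead of a MySubclass
#   -> __init__ then runs on that returned object, i.e.
#      SubclassTensorProxy.__init__(proxy, x, scale)
#           # not MySubclass.__init__ on a real MySubclass instance,
#           # presumably why the proxy can already have a _name here
#           # (has_name_before_init is True in that case)
```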
I guess I'm still confused about how and why the proxy class is entangled with the actual subclass
self._tensors = tensors
self._non_tensors = non_tensors
bsym = prims.tensor_subclass_ctor.bind(
This is interesting. Why does this class either create a new proxy or change the actual trace?
This path happens while tracing, and a call of `Proxy.__init__` is not recorded in a trace by default, so it's necessary to deliberately register a BoundSymbol to the currently active scope.
OK; but why does proxy creation get recorded into the trace?
if a function creates a subclass inside it then shouldn't a trace of it have a BoundSymbol to represent it?
I think it's OK to put tensor subclass creation into a trace (although I'm curious if we can dce the creation if the subclass is just flattened and used once later), but I'm not sure why the operator that creates the actual tensor subclass is also a constructor for the proxy. The existing tensor proxies don't entangle their tensor factory methods with the creation of the proxy, for example
> Does the actual call to `_make_wrapper_subclass` tell us anything about the subclass?

`torch.Tensor._make_wrapper_subclass` takes `cls`. The lookaside could let us tell things other than actual tensors and other values to be registered to the output of the method.

> I guess my question, then, is if we can also just call `ComplexTensor` like a practitioner would?

I don't get where this call of `ComplexTensor` would be. Could you elaborate on it?

> Then the meta for `ComplexTensor` could return a `SubclassTensorProxy` or even a `ComplexTensorSubclassProxy`.

How and when would we register a `ComplexTensorSubclassProxy`? Would this indicate that Thunder is going to detect a subclass and define a proxy for it dynamically?
Yes! What if we detected that a subclass constructor was called and dynamically created a symbol for the constructor and the corresponding proxy?
The call to `ComplexTensor` could be in the resulting Thunder program, like other bound symbols. Or maybe we want to represent `ComplexTensor` construction as an unflatten call, like `ComplexTensor.__tensor_unflatten__`? Both seem reasonable to me.
If I do that dynamically, I want that subclass proxy to correctly capture the result type of ops with that subclass. However, I don't see a clear path to infer the result type of ops with subclasses, other than doing what AOT Autograd does with torch.fx tracing. Thus I don't think dynamic registration is that different from what I have in this PR.
Representing subclass construction as an unflatten call would then call `__init__` of that subclass, so I'm not following the point.
The goal is to divorce the proxy from the actual runtime object to simplify the code and its concepts. How else would we infer the decomposition of operations called on a tensor subclass, and whether the result is also a member of the subclass?
I understand that init gets called on the object at runtime, of course. The goal is not to somehow circumvent the proper construction of the object at runtime (unless it can be elided for performance).
Maybe I'm not understanding this well, but I'm a bit worried that we don't have well-defined semantics here and that the traces are not representing what's up.
We should absolutely know when we want to construct and use the subclass object (if we return it or something wants the subclass object) and when we don't (all other cases), and we should also represent what the compute will be in the trace.
We also want to minimize admin overhead during the compute, which we are not doing a great job of today (having looked at wall clock vs. GPU self time for Llama 1b today), so adding the overhead of dealing with subclasses at compute time should likely be minimized.
else:
    cur_tail_scope.append(bsym)

def replace(self, **changes):
Why is a replace function needed?
Because I don't want a `.replace(...)` call to replace a subclass tensor proxy with a plain tensor proxy.
in preference to the old values but overridable by keyword arguments.
Note that the copy will use the current (environment) tracectx."""

like = changes.get("like")
If `changes` doesn't have `like` as a key, won't this throw a KeyError?
If the key is not found then the default value, in this case `None`, is returned: https://docs.python.org/3/library/stdtypes.html#dict.get
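A quick illustration of that `dict.get` behavior (standalone Python, not code from the PR):

```python
changes = {"dtype": "float16"}

print(changes.get("like"))        # None, no KeyError is raised
print(changes.get("like", None))  # equivalent, with the default spelled out
```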
    thunder_fsdp_padding_size,
) = _infer_tensor_properties(
    like,
    changes.get("shape", self._shape if like is None else None),
Why not `*[changes.get(key, None) for key in ('shape', 'device', ...)]`?
I copied these lines from `TensorProxy.replace` and I'm lazy enough not to do that.
But doesn't the current code mean that if `like` is specified then the shape of the like tensor is overridden with the shape of the current tensor? Is that what's intended?
@instantiate(
    dtypes=(thunder.core.dtypes.float32,),
)
def test_func_of_subclass_ctor_wrapper(executor, device, _):
Do these tests just check that the tensor subclass can be constructed?
What does this PR do?

The scope of this PR is tensor subclasses that call `torch.Tensor._make_wrapper_subclass` in their dunder new (and define `__torch_dispatch__`, `__tensor_flatten__`, and `__tensor_unflatten__`).

This PR adds a new proxy for such tensor subclasses and implements a lookaside for `_make_wrapper_subclass` that returns an instance of the new proxy. `MySubclass(...)` calls `MySubclass.__new__(cls, ...)` before calling `__init__(...)` on the return value of the dunder new. Since this PR has the lookaside and it returns an instance of a proxy, not `MySubclass`, the `__init__` of the new proxy is called. This is the reason the new proxy has if-else branches inside of its dunder new.

Caveat: this assumes that `Subclass.__new__` does not have kwargs, only positional args (a small illustration follows the checklist below).

- [ ] The proxy should express `__tensor_flatten__` and `__tensor_unflatten__`
- [ ] `__torch_dispatch__` and get the correct output type, based on 1393 #1394
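To make the caveat concrete, a hypothetical subclass (not from this PR) whose construction stays within the stated scope:

```python
import torch

class MyScaled(torch.Tensor):
    @staticmethod
    def __new__(cls, data, scale):
        # storage-less wrapper shell; the real data is kept as an attribute
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data, self._scale = data, scale

x = torch.randn(2, 2)
y = MyScaled(x, 2.0)        # positional args to __new__: within the stated scope
# MyScaled(x, scale=2.0)    # kwargs reaching __new__: outside the scope per the caveat
```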