Commit bccf751

printer translating thunder dtype/device to torch's

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar committed Nov 24, 2024
1 parent 77e27d3 commit bccf751
Showing 3 changed files with 39 additions and 4 deletions.
thunder/core/prims.py (9 additions & 1 deletion)
@@ -4118,8 +4118,16 @@ def printer_of_tensor_subclass_ctor(
     else:
         cls: torch._C._TensorMeta = wrapped_cls.obj
     tensors, non_tensors = arg_printables[-2:]
+    new_non_tensors = []
+    for a in non_tensors:
+        if isinstance(a, dtypes.dtype):
+            new_non_tensors.append(dtypes.to_torch_dtype(a))
+        elif isinstance(a, devices.Device):
+            new_non_tensors.append(devices.to_torch_device(a))
+        else:
+            new_non_tensors.append(a)
 
-    arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *non_tensors])
+    arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors])
     kwarg_str = ""
 
     result_str: str
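
Note on the change: the printer previously prettyprinted Thunder's own dtype and device objects straight into the generated source; translating them first yields torch-native values instead. A minimal sketch of the effect, assuming a lightning-thunder installation (to_torch_dtype and to_torch_device are the same helpers the hunk calls; the sample values are invented):

    import torch
    from thunder.core import devices, dtypes

    # Non-tensor printables as they might appear for a tensor-subclass ctor
    non_tensors = [dtypes.float32, devices.Device("cpu"), True]

    # The translation the printer now applies before building arg_str
    new_non_tensors = [
        dtypes.to_torch_dtype(a) if isinstance(a, dtypes.dtype)
        else devices.to_torch_device(a) if isinstance(a, devices.Device)
        else a
        for a in non_tensors
    ]
    assert new_non_tensors[0] is torch.float32
    assert isinstance(new_non_tensors[1], torch.device)
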
thunder/executors/torchex.py (15 additions & 2 deletions)
@@ -2210,7 +2210,15 @@ def _shape_impl(t):


 def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
-    return cls(*tensors, *non_tensors)
+    new_non_tensors = []
+    for a in non_tensors:
+        if isinstance(a, dtypes.dtype):
+            new_non_tensors.append(to_torch_dtype(a))
+        elif isinstance(a, devices.Device):
+            new_non_tensors.append(to_torch_device(a))
+        else:
+            new_non_tensors.append(a)
+    return cls(*tensors, *new_non_tensors)
 
 
 def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
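
This is the second copy of the same normalization loop (the prims.py printer above is the first, and thunder/transforms/tensor_subclasses.py below adds a third). A shared helper would keep the call sites in sync; the sketch below is purely a refactoring idea, not part of this commit, and as_torch_metadata is an invented name:

    from thunder.core import devices, dtypes

    def as_torch_metadata(values):
        # Translate thunder dtypes/devices to their torch equivalents;
        # pass every other value through unchanged.
        out = []
        for v in values:
            if isinstance(v, dtypes.dtype):
                out.append(dtypes.to_torch_dtype(v))
            elif isinstance(v, devices.Device):
                out.append(devices.to_torch_device(v))
            else:
                out.append(v)
        return out

    def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
        return cls(*tensors, *as_torch_metadata(non_tensors))
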
@@ -2231,7 +2239,6 @@ def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
     )
     if filtered_types:
         new_imports = {t.__name__: t for t in filtered_types}
-        print(f"$$$ [torchexecutor] {bsym.sym = }| {new_imports = }")
         bsym._import_ctx.update(new_imports)


@@ -2268,6 +2275,12 @@ def unflatten_tensor_subclass_impl(
     inner_tensors: dict[str, TensorLike],
     metadata: dict,
 ):
+    for key in metadata:
+        v = metadata[key]
+        if isinstance(v, dtypes.dtype):
+            metadata[key] = to_torch_dtype(v)
+        elif isinstance(v, devices.Device):
+            metadata[key] = to_torch_device(v)
     return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


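
For context, the contract served here: a traceable wrapper subclass splits into inner tensors plus a metadata dict via __tensor_flatten__, and __tensor_unflatten__ rebuilds it, expecting torch-native metadata values, which is what the translation above guarantees. A toy subclass illustrating the round trip; ScaledTensor is invented for this sketch, not a class from the repository:

    import torch

    class ScaledTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, data, scale):
            return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

        def __init__(self, data, scale):
            self._data = data    # inner tensor
            self._scale = scale  # non-tensor metadata

        def __tensor_flatten__(self):
            # inner-tensor attribute names, plus the non-tensor metadata dict
            return ["_data"], {"scale": self._scale}

        @staticmethod
        def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
            return ScaledTensor(inner_tensors["_data"], metadata["scale"])

    t = ScaledTensor(torch.ones(2), 0.5)
    attrs, meta = t.__tensor_flatten__()
    t2 = ScaledTensor.__tensor_unflatten__({name: getattr(t, name) for name in attrs}, meta, t.shape, None)

The -1, -1 in the hunk stand in for outer_size and outer_stride, which this toy likewise ignores.
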
thunder/transforms/tensor_subclasses.py (15 additions & 1 deletion)
@@ -94,7 +94,15 @@ def _make_fake_subclass_tensor_from_subclass_tensor_proxy(
             [_materialize_tensor_proxy(t, fake_tensor_mode=fake_tensor_mode) for t in tensor_proxy._tensors],
         )
     )
-    metadata = dict(zip(non_tensor_attr_names, tensor_proxy._non_tensors))
+    new_non_tensors = []
+    for a in tensor_proxy._non_tensors:
+        if isinstance(a, dtypes.dtype):
+            new_non_tensors.append(dtypes.to_torch_dtype(a))
+        elif isinstance(a, devices.Device):
+            new_non_tensors.append(devices.to_torch_device(a))
+        else:
+            new_non_tensors.append(a)
+    metadata = dict(zip(non_tensor_attr_names, new_non_tensors))
     subclass_tensor = subclass_type.__tensor_unflatten__(
         inner_tensors,
         metadata,
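
Here the metadata dict is being assembled for __tensor_unflatten__ under a fake tensor mode, so the proxy's non-tensors get the same translation before being zipped with their attribute names. The hunk's loop in isolation, with invented attribute names and values:

    import torch
    from thunder.core import devices, dtypes

    non_tensor_attr_names = ["dtype", "device", "block_size"]
    proxy_non_tensors = [dtypes.bfloat16, devices.Device("cpu"), 64]  # as stored on the proxy

    new_non_tensors = []
    for a in proxy_non_tensors:
        if isinstance(a, dtypes.dtype):
            new_non_tensors.append(dtypes.to_torch_dtype(a))
        elif isinstance(a, devices.Device):
            new_non_tensors.append(devices.to_torch_device(a))
        else:
            new_non_tensors.append(a)

    metadata = dict(zip(non_tensor_attr_names, new_non_tensors))
    assert metadata["dtype"] is torch.bfloat16 and metadata["block_size"] == 64
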
@@ -386,8 +394,14 @@ def ctor(tensors, metadata):
     def transform_out(out: torch.Tensor) -> OutputWrapperForFxTracing:
         orig_output.append(out)
         if is_traceable_wrapper_subclass(out):
+            from enum import Enum
+
             attrs, metadata = out.__tensor_flatten__()
             tensors = [getattr(out, name) for name in attrs]
+            for key in metadata:
+                v = metadata[key]
+                if issubclass(type(v), Enum) and not isinstance(v, (torch.dtype, torch.device)):
+                    metadata[key] = str(metadata[key])
             output = OutputWrapperForFxTracing(dict(zip(attrs, tensors)), metadata)
         else:
             output = OutputWrapperForFxTracing(out, None)
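
The second hunk handles a different non-tensor kind: Enum members in __tensor_flatten__ metadata are stringified before the output is wrapped for FX tracing. The torch.dtype/torch.device guard is defensive, since neither is actually an Enum subclass. The loop in isolation, with an invented Layout enum:

    from enum import Enum
    import torch

    class Layout(Enum):
        ROW_MAJOR = 0
        TILED = 1

    metadata = {"layout": Layout.TILED, "dtype": torch.float32, "axis": 1}
    for key in metadata:
        v = metadata[key]
        if issubclass(type(v), Enum) and not isinstance(v, (torch.dtype, torch.device)):
            metadata[key] = str(v)  # e.g. "Layout.TILED"; non-Enum values are untouched
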
