diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index 24ed43208..b3142a05c 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -4118,8 +4118,16 @@ def printer_of_tensor_subclass_ctor(
     else:
         cls: torch._C._TensorMeta = wrapped_cls.obj
     tensors, non_tensors = arg_printables[-2:]
+    new_non_tensors = []
+    for a in non_tensors:
+        if isinstance(a, dtypes.dtype):
+            new_non_tensors.append(dtypes.to_torch_dtype(a))
+        elif isinstance(a, devices.Device):
+            new_non_tensors.append(devices.to_torch_device(a))
+        else:
+            new_non_tensors.append(a)
-    arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *non_tensors])
+    arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors])
     kwarg_str = ""
 
     result_str: str
 
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 625da83bd..85b2cd191 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -2210,7 +2210,15 @@ def _shape_impl(t):
 
 
 def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
-    return cls(*tensors, *non_tensors)
+    new_non_tensors = []
+    for a in non_tensors:
+        if isinstance(a, dtypes.dtype):
+            new_non_tensors.append(to_torch_dtype(a))
+        elif isinstance(a, devices.Device):
+            new_non_tensors.append(to_torch_device(a))
+        else:
+            new_non_tensors.append(a)
+    return cls(*tensors, *new_non_tensors)
 
 
 def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
@@ -2231,7 +2239,6 @@ def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
         )
     if filtered_types:
         new_imports = {t.__name__: t for t in filtered_types}
-        print(f"$$$ [torchexecutor] {bsym.sym = }| {new_imports = }")
         bsym._import_ctx.update(new_imports)
 
 
@@ -2268,6 +2275,12 @@ def unflatten_tensor_subclass_impl(
     inner_tensors: dict[str, TensorLike],
     metadata: dict,
 ):
+    for key in metadata:
+        v = metadata[key]
+        if isinstance(v, dtypes.dtype):
+            metadata[key] = to_torch_dtype(v)
+        elif isinstance(v, devices.Device):
+            metadata[key] = to_torch_device(v)
     return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)
 
 
diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py
index 8e901717a..4e145569d 100644
--- a/thunder/transforms/tensor_subclasses.py
+++ b/thunder/transforms/tensor_subclasses.py
@@ -94,7 +94,15 @@ def _make_fake_subclass_tensor_from_subclass_tensor_proxy(
             [_materialize_tensor_proxy(t, fake_tensor_mode=fake_tensor_mode) for t in tensor_proxy._tensors],
         )
     )
-    metadata = dict(zip(non_tensor_attr_names, tensor_proxy._non_tensors))
+    new_non_tensors = []
+    for a in tensor_proxy._non_tensors:
+        if isinstance(a, dtypes.dtype):
+            new_non_tensors.append(dtypes.to_torch_dtype(a))
+        elif isinstance(a, devices.Device):
+            new_non_tensors.append(devices.to_torch_device(a))
+        else:
+            new_non_tensors.append(a)
+    metadata = dict(zip(non_tensor_attr_names, new_non_tensors))
     subclass_tensor = subclass_type.__tensor_unflatten__(
         inner_tensors,
         metadata,
@@ -386,8 +394,14 @@ def ctor(tensors, metadata):
     def transform_out(out: torch.Tensor) -> OutputWrapperForFxTracing:
         orig_output.append(out)
         if is_traceable_wrapper_subclass(out):
+            from enum import Enum
+
             attrs, metadata = out.__tensor_flatten__()
             tensors = [getattr(out, name) for name in attrs]
+            for key in metadata:
+                v = metadata[key]
+                if issubclass(type(v), Enum) and not isinstance(v, (torch.dtype, torch.device)):
+                    metadata[key] = str(metadata[key])
             output = OutputWrapperForFxTracing(dict(zip(attrs, tensors)), metadata)
         else:
             output = OutputWrapperForFxTracing(out, None)
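
For context, every hunk above applies the same normalization: Thunder's `dtypes.dtype` and `devices.Device` metadata is converted to `torch.dtype` / `torch.device` before it reaches a tensor-subclass constructor or `__tensor_unflatten__`. Below is a minimal sketch of that shared pattern, assuming only the `to_torch_dtype` / `to_torch_device` helpers already used in the patch; the `_to_torch_metadata` name is hypothetical and not part of the change.

```python
# Sketch only: the thunder -> torch metadata conversion repeated in the hunks above.
# `_to_torch_metadata` is a hypothetical helper name, not something this patch adds.
from thunder.core import dtypes, devices


def _to_torch_metadata(value):
    """Convert Thunder dtype/device metadata to torch equivalents; pass anything else through."""
    if isinstance(value, dtypes.dtype):
        return dtypes.to_torch_dtype(value)
    if isinstance(value, devices.Device):
        return devices.to_torch_device(value)
    return value


# Illustrative usage against a subclass's non-tensor metadata:
# converted = [_to_torch_metadata(a) for a in non_tensors]
```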