From 9d5d1aac0d011eab443846cdf0e93a036484d929 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 24 Nov 2024 16:28:45 +0900 Subject: [PATCH] printer translating thunder dtype/device to torchs Signed-off-by: Masaki Kozuki --- thunder/core/prims.py | 10 +++++++++- thunder/executors/torchex.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 24ed43208..b3142a05c 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -4118,8 +4118,16 @@ def printer_of_tensor_subclass_ctor( else: cls: torch._C._TensorMeta = wrapped_cls.obj tensors, non_tensors = arg_printables[-2:] + new_non_tensors = [] + for a in non_tensors: + if isinstance(a, dtypes.dtype): + new_non_tensors.append(dtypes.to_torch_dtype(a)) + elif isinstance(a, devices.Device): + new_non_tensors.append(devices.to_torch_device(a)) + else: + new_non_tensors.append(a) - arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *non_tensors]) + arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors]) kwarg_str = "" result_str: str diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 625da83bd..a63eb25a7 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2210,7 +2210,15 @@ def _shape_impl(t): def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors): - return cls(*tensors, *non_tensors) + new_non_tensors = [] + for a in non_tensors: + if isinstance(a, dtypes.dtype): + new_non_tensors.append(to_torch_dtype(a)) + elif isinstance(a, devices.Device): + new_non_tensors.append(to_torch_device(a)) + else: + new_non_tensors.append(a) + return cls(*tensors, *new_non_tensors) def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: