Skip to content

Commit

Permalink
printer translating thunder dtype/device to torchs
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 24, 2024
1 parent 77e27d3 commit 9d5d1aa
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
10 changes: 9 additions & 1 deletion thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -4118,8 +4118,16 @@ def printer_of_tensor_subclass_ctor(
else:
cls: torch._C._TensorMeta = wrapped_cls.obj
tensors, non_tensors = arg_printables[-2:]
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(dtypes.to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(devices.to_torch_device(a))
else:
new_non_tensors.append(a)

arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *non_tensors])
arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors])
kwarg_str = ""

result_str: str
Expand Down
10 changes: 9 additions & 1 deletion thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2210,7 +2210,15 @@ def _shape_impl(t):


def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
return cls(*tensors, *non_tensors)
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(to_torch_device(a))
else:
new_non_tensors.append(a)
return cls(*tensors, *new_non_tensors)


def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
Expand Down

0 comments on commit 9d5d1aa

Please sign in to comment.