diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index ae7744dda3..626f500e6c 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -3,7 +3,7 @@ import copy from enum import auto, Enum from numbers import Number -from typing import Type, Optional, Any, Tuple, List, Union +from typing import Any from collections.abc import Callable from collections.abc import Sequence @@ -1974,8 +1974,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = for idx, s in enumerate(t.shape) ) else: + # NOTE Without tuple(t.shape) then the shape would be a torch.Size object shape = tuple(t.shape) - # NOTE Without tuple(t.shape) then the shape would be a torch.Size object return TensorProxy( name, shape=tuple(shape),