diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 2f7f054cd..9799ea056 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1896,12 +1896,19 @@ def __init__(self, *args, **kwargs): kwarg_non_tensors = kwargs.pop("non_tensors", []) subclass_type = kwargs.pop("subclass_type", None) + has_name_before_init = hasattr(self, "_name") # If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass` # where `self` should already have gotten its name. flat_args, spec = tree_flatten((args, kwargs)) - tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args)) - non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args)) - has_name_before_init = hasattr(self, "_name") + tensors: list[TensorProxy] = [] + non_tensors: list[Any] = [] + for t in args + tuple(kwargs.values()): + if type(t) is SubclassTensorProxy: + continue + if type(t) is TensorProxy: + tensors.append(t) + else: + non_tensors.append(t) is_dunder_init_following_make_wrapper_subclass: bool = False if tensors: