diff --git a/thunder/core/functionalization.py b/thunder/core/functionalization.py index aac6f1502..fc6594748 100644 --- a/thunder/core/functionalization.py +++ b/thunder/core/functionalization.py @@ -41,7 +41,6 @@ def bsym_of_to_return_self(bsym: BoundSymbol): def check_inplace_to_views(computation_trace: Trace) -> dict[VariableInterface, TensorProxy]: """Error out if in-place op that outputs of different number of elements from the input and the input has other consumers.""" - from thunder.core import utils import thunder.torch as ltorch producer_bsyms = producers(computation_trace) @@ -58,7 +57,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool: return bsym.sym.tags and tag in bsym.sym.tags swap_map: dict[VariableInterface, TensorProxy] = {} - consumers = utils.consumers(computation_trace) + consumer_map = consumers(computation_trace) bsym: BoundSymbol for bsym in filter(lambda b: has_tag(b, prims.OpTags.IN_PLACE), computation_trace.bound_symbols): in_tensor: TensorProxy = list(filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))[0] @@ -72,7 +71,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool: # assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones` continue orig_tensor = flat_tensor_proxy_args[0] - consumer_of_orig_tensor = consumers[orig_tensor] + consumer_of_orig_tensor = consumer_map[orig_tensor] # When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe. # Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original # is an arg or a kwarg.