Skip to content

Commit

Permalink
use the already imported consumers, not through utils.consumers (#…
Browse files Browse the repository at this point in the history
…1436)

Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar authored Nov 13, 2024
1 parent 3d27367 commit 85dc00b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions thunder/core/functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def bsym_of_to_return_self(bsym: BoundSymbol):

def check_inplace_to_views(computation_trace: Trace) -> dict[VariableInterface, TensorProxy]:
"""Error out if in-place op that outputs of different number of elements from the input and the input has other consumers."""
from thunder.core import utils
import thunder.torch as ltorch

producer_bsyms = producers(computation_trace)
Expand All @@ -58,7 +57,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
return bsym.sym.tags and tag in bsym.sym.tags

swap_map: dict[VariableInterface, TensorProxy] = {}
consumers = utils.consumers(computation_trace)
consumer_map = consumers(computation_trace)
bsym: BoundSymbol
for bsym in filter(lambda b: has_tag(b, prims.OpTags.IN_PLACE), computation_trace.bound_symbols):
in_tensor: TensorProxy = list(filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))[0]
Expand All @@ -72,7 +71,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
# assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones`
continue
orig_tensor = flat_tensor_proxy_args[0]
consumer_of_orig_tensor = consumers[orig_tensor]
consumer_of_orig_tensor = consumer_map[orig_tensor]
# When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe.
# Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original
# is an arg or a kwarg.
Expand Down

0 comments on commit 85dc00b

Please sign in to comment.