From db58b670990671c8adaa0ad3331060b7f7ad994f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 18 Nov 2024 15:11:22 -0800 Subject: [PATCH 1/3] refactor --- thunder/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 5f4b75d6e..92183ed1c 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -575,15 +575,13 @@ def get_computation_and_inputs(*args, **kwargs): arg_to_idx[a] = i tensor_indices: int = [] - tensor_args_consumed_by_inplace_grouped_by_numel: dict[int, list[TensorProxy]] = defaultdict(list) for bsym in filter(lambda b: b.sym.id == prims.PrimIDs.COPY_, computation_trc.bound_symbols): t = bsym.flat_proxy_args[1] index = arg_to_idx[t] - numel = t.numel - tensor_args_consumed_by_inplace_grouped_by_numel[numel].append(index) tensor_indices.append(index) - if len(tensor_args_consumed_by_inplace_grouped_by_numel) > 1: - vanilla_tensor_args = set(tensor_indices) + tmp_vanilla_tensor_args = set(tensor_indices) + if len(tmp_vanilla_tensor_args) > 1: + vanilla_tensor_args = tmp_vanilla_tensor_args if epilogue_trc is not None: epilogue_traces = [epilogue_trc] From 59caa154184a420985b6571cb7688bd67bf56df3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 27 Nov 2024 15:30:31 -0800 Subject: [PATCH 2/3] what does this check even do?! --- thunder/__init__.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 7a602c29f..cd17b93af 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -264,7 +264,6 @@ def _recursive_jit_call_warning() -> None: "backward_fn", "backward_traces", "return_none_instead_of_grads", - "vanilla_tensor_args", ], ) @@ -548,9 +547,7 @@ def get_computation_and_inputs(*args, **kwargs): computation_traces.append(computation_trc) orig_to_view_swap_map = check_inplace_to_views(computation_trc) - vanilla_tensor_args: set[int] | None = None if not compile_options.get("skip_inplace_functionalization", False): - orig_len = len(computation_traces) alias_tensor_indices = [] if alias_tensor_indices_str := cache_info["alias_tensor_indices"]: alias_tensor_indices: list[list[int]] = [ @@ -564,25 +561,6 @@ def get_computation_and_inputs(*args, **kwargs): ) ) computation_trc = computation_traces[-1] - if len(computation_traces) > orig_len: - from thunder.core.pytree import tree_flatten - from thunder.core.utils import ProxyDict - - flat_args, _ = tree_flatten((computation_trc.args, computation_trc.kwargs)) - arg_to_idx = ProxyDict() - for i, a in enumerate(flat_args): - if not isinstance(a, TensorProxy): - continue - arg_to_idx[a] = i - - tensor_indices: int = [] - for bsym in filter(lambda b: b.sym.id == prims.PrimIDs.COPY_, computation_trc.bound_symbols): - t = bsym.flat_proxy_args[1] - index = arg_to_idx[t] - tensor_indices.append(index) - tmp_vanilla_tensor_args = set(tensor_indices) - if len(tmp_vanilla_tensor_args) > 1: - vanilla_tensor_args = tmp_vanilla_tensor_args if epilogue_trc is not None: epilogue_traces = [epilogue_trc] @@ -722,7 +700,6 @@ def get_computation_and_inputs(*args, **kwargs): backward_fn, backward_traces, return_none_instead_of_grads, - vanilla_tensor_args, ) if cd.cache_option is not CACHE_OPTIONS.NO_CACHING: cs.interpreter_cache.append(cache_entry) @@ -775,18 +752,6 @@ def wrapped(*args, **kwargs): return wrapped - def check_storage_aliases(cache_entry, args): - if cache_entry.vanilla_tensor_args: - if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args): - alias_tensor_indices = alias_tensor_indices_str - alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")} - vanilla_tensor_args = cache_entry.vanilla_tensor_args - check( - not vanilla_tensor_args & alias_tensor_indices, - lambda: f"It seems that {vanilla_tensor_args} are {alias_tensor_indices=} share their storage and some of them are modified in-place", - NotImplementedError, - ) - def maybe_connect_to_autograd(cache_entry, result): if cache_entry.backward_fn: # If the backward function is available, we need to connect the @@ -821,8 +786,6 @@ def fn_(*args, **kwargs) -> Any: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) - check_storage_aliases(cache_entry, inps) - result = cache_entry.computation_fn(*inps) result = maybe_connect_to_autograd(cache_entry, result) result = maybe_call_epilogue(cache_entry, result, pro_to_epi) From 3316b428fb2d8054f28d7ecabdffa47ec72e4c3d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 27 Nov 2024 15:34:21 -0800 Subject: [PATCH 3/3] removing vanilla_args --- thunder/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index cd17b93af..9ca25dea3 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -465,7 +465,6 @@ def get_computation_and_inputs(*args, **kwargs): backward_fn, backward_traces, _return_none_instead_of_grads, - _vanilla_args, ) = cache_entry try: inps, pro_to_epi = pro(*args, **kwargs)