Skip to content

Commit

Permalink
wrap_return_value_together_with_argments -> `wrap_return_value_together_with_arguments`
Browse files Browse the repository at this point in the history

Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 13, 2024
1 parent 7912102 commit ed4765e
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from thunder.core.transform_common import (
dce,
Transform,
wrap_return_value_together_with_argments,
wrap_return_value_together_with_arguments,
unwrap_return_value,
remove_context_manager_prims_from_trace,
)
Expand Down Expand Up @@ -535,7 +535,7 @@ def get_computation_and_inputs(*args, **kwargs):
prologue_traces = [prologue_trc]
computation_traces = [computation_trc]

computation_trc = wrap_return_value_together_with_argments(computation_trc)
computation_trc = wrap_return_value_together_with_arguments(computation_trc)
computation_traces.append(computation_trc)

computation_trc = remove_context_manager_prims_from_trace(computation_trc)
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl
if bsym in bsym_to_copy_bsyms:
functionalized_bsyms.extend(bsym_to_copy_bsyms[bsym])
copy_bsym = functionalized_bsyms[-1]
# wrap_return_value_together_with_argments places all the arguments in the return value
# wrap_return_value_together_with_arguments places all the arguments in the return value
# We swap these arguments in the return value with the outputs of copies onto them
# This prevents subsequent transforms from ordering the return statement before those copies
swap_map_for_return[variableify(copy_bsym.flat_proxy_args[0])] = copy_bsym.flat_proxy_outs[0]
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def process_bound_symbols(src_bound_symbols, target_bound_symbols):
return output


def wrap_return_value_together_with_argments(trace: Trace) -> Trace:
def wrap_return_value_together_with_arguments(trace: Trace) -> Trace:
last = trace.bound_symbols[-1]
assert last.sym.id == prims.PrimIDs.RETURN
flat_args, _ = tree_flatten((trace.args, trace.kwargs))
Expand Down
4 changes: 2 additions & 2 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from thunder.core.transform_common import (
dce,
Transform,
wrap_return_value_together_with_argments,
wrap_return_value_together_with_arguments,
unwrap_return_value,
VJPDual,
)
Expand Down Expand Up @@ -1493,7 +1493,7 @@ def python_callable(*args, **kwargs):
grad(python_callable), *computation_trc.args, **computation_trc.kwargs
)

gradtrc = wrap_return_value_together_with_argments(gradtrc)
gradtrc = wrap_return_value_together_with_arguments(gradtrc)
gradtrc = dce(gradtrc)
return prologue_trc, gradtrc, epilogue_trc

Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ def test_forward_and_backward_from_trace(executor, device, _):
from thunder.clang import cos, sin
import thunder.torch as ltorch
from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad
from thunder.core.transform_common import wrap_return_value_together_with_argments
from thunder.core.transform_common import wrap_return_value_together_with_arguments

def func(a, b, *, c):
d = a + b + c
Expand All @@ -1145,7 +1145,7 @@ def func(a, b, *, c):
b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True)
initial_trace = trace(inline_trace=False)(func, a, b, c=c)
wrapped_trace = wrap_return_value_together_with_argments(initial_trace)
wrapped_trace = wrap_return_value_together_with_arguments(initial_trace)
fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace)
fw = executor.make_callable(fw_trace)
bw = executor.make_callable(bw_trace)
Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_rematerialization_with_forward_and_backward_from_trace(executor: TestEx
from thunder.clang import cos, sin
import thunder.torch as ltorch
from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad
from thunder.core.transform_common import wrap_return_value_together_with_argments
from thunder.core.transform_common import wrap_return_value_together_with_arguments
from thunder.common import transform_for_execution
from thunder.core.rematerialization import rematerialize_forward_and_backward

Expand All @@ -74,7 +74,7 @@ def func(a, b, *, c):
requires_grad=True,
)
trace = trace(inline_trace=False)(func, a, b, c=c)
trace = wrap_return_value_together_with_argments(trace)
trace = wrap_return_value_together_with_arguments(trace)
fw_trace, bw_trace = forward_and_backward_from_trace(trace)

fw_extraces = transform_for_execution(fw_trace, executors_list=executor.executors_list())
Expand Down

0 comments on commit ed4765e

Please sign in to comment.