diff --git a/thunder/__init__.py b/thunder/__init__.py index c09c1cc9b..5f9bd9f52 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -32,7 +32,7 @@ from thunder.core.transform_common import ( dce, Transform, - wrap_return_value_together_with_argments, + wrap_return_value_together_with_arguments, unwrap_return_value, remove_context_manager_prims_from_trace, ) @@ -535,7 +535,7 @@ def get_computation_and_inputs(*args, **kwargs): prologue_traces = [prologue_trc] computation_traces = [computation_trc] - computation_trc = wrap_return_value_together_with_argments(computation_trc) + computation_trc = wrap_return_value_together_with_arguments(computation_trc) computation_traces.append(computation_trc) computation_trc = remove_context_manager_prims_from_trace(computation_trc) diff --git a/thunder/core/functionalization.py b/thunder/core/functionalization.py index aac6f1502..ec95a038b 100644 --- a/thunder/core/functionalization.py +++ b/thunder/core/functionalization.py @@ -604,7 +604,7 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl if bsym in bsym_to_copy_bsyms: functionalized_bsyms.extend(bsym_to_copy_bsyms[bsym]) copy_bsym = functionalized_bsyms[-1] - # wrap_return_value_together_with_argments places all the arguments in the return value + # wrap_return_value_together_with_arguments places all the arguments in the return value # We swap these arguments in the return value with the outputs of copies onto them # This prevents subsequent transforms from ordering the return statement before those copies swap_map_for_return[variableify(copy_bsym.flat_proxy_args[0])] = copy_bsym.flat_proxy_outs[0] diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index e21ca4b28..5ed68c417 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -456,7 +456,7 @@ def process_bound_symbols(src_bound_symbols, target_bound_symbols): return output -def wrap_return_value_together_with_argments(trace: Trace) -> Trace: +def wrap_return_value_together_with_arguments(trace: Trace) -> Trace: last = trace.bound_symbols[-1] assert last.sym.id == prims.PrimIDs.RETURN flat_args, _ = tree_flatten((trace.args, trace.kwargs)) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 2bf91372f..2d9d88cdd 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -75,7 +75,7 @@ from thunder.core.transform_common import ( dce, Transform, - wrap_return_value_together_with_argments, + wrap_return_value_together_with_arguments, unwrap_return_value, VJPDual, ) @@ -1493,7 +1493,7 @@ def python_callable(*args, **kwargs): grad(python_callable), *computation_trc.args, **computation_trc.kwargs ) - gradtrc = wrap_return_value_together_with_argments(gradtrc) + gradtrc = wrap_return_value_together_with_arguments(gradtrc) gradtrc = dce(gradtrc) return prologue_trc, gradtrc, epilogue_trc diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 8f64ea449..87ac51af8 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1134,7 +1134,7 @@ def test_forward_and_backward_from_trace(executor, device, _): from thunder.clang import cos, sin import thunder.torch as ltorch from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad - from thunder.core.transform_common import wrap_return_value_together_with_argments + from thunder.core.transform_common import wrap_return_value_together_with_arguments def func(a, b, *, c): d = a + b + c @@ -1145,7 +1145,7 @@ def func(a, b, *, c): b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True) c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True) initial_trace = trace(inline_trace=False)(func, a, b, c=c) - wrapped_trace = wrap_return_value_together_with_argments(initial_trace) + wrapped_trace = wrap_return_value_together_with_arguments(initial_trace) fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace) fw = executor.make_callable(fw_trace) bw = executor.make_callable(bw_trace) diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index c42c183b7..32a3eee7c 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -53,7 +53,7 @@ def test_rematerialization_with_forward_and_backward_from_trace(executor: TestEx from thunder.clang import cos, sin import thunder.torch as ltorch from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad - from thunder.core.transform_common import wrap_return_value_together_with_argments + from thunder.core.transform_common import wrap_return_value_together_with_arguments from thunder.common import transform_for_execution from thunder.core.rematerialization import rematerialize_forward_and_backward @@ -74,7 +74,7 @@ def func(a, b, *, c): requires_grad=True, ) trace = trace(inline_trace=False)(func, a, b, c=c) - trace = wrap_return_value_together_with_argments(trace) + trace = wrap_return_value_together_with_arguments(trace) fw_trace, bw_trace = forward_and_backward_from_trace(trace) fw_extraces = transform_for_execution(fw_trace, executors_list=executor.executors_list())