diff --git a/thunder/__init__.py b/thunder/__init__.py index 45a56a16a2..03ab5efa0b 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -348,7 +348,6 @@ def jit( sharp_edges=sharp_edges, using_jit=True, disable_torch_autograd_support=disable_torch_autograd, - use_rematerialization=False, only_execute_prims=False, disable_preprocessing=True, compile_options=compile_options, diff --git a/thunder/common.py b/thunder/common.py index 4107a00550..1725050d69 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -199,7 +199,6 @@ def __init__( only_execute_prims: bool = False, disable_preprocessing: bool = False, disable_torch_autograd_support: bool = False, - use_rematerialization: bool = False, debug_log: None | StringIO = None, compile_options: dict[str, Any] = {}, get_computation_and_inputs: Callable | None = None, @@ -254,7 +253,6 @@ def __init__( self.fn = fn self.only_execute_prims = only_execute_prims self.disable_preprocessing = disable_preprocessing - self.use_rematerialization = use_rematerialization self.disable_torch_autograd_support = disable_torch_autograd_support self.debug_log = debug_log @@ -636,7 +634,6 @@ def transform_for_execution( executors_list: Sequence[Executor], *, only_execute_prims=False, - use_rematerialization=True, use_del_last_used=True, ) -> list[TraceCtx]: traces: list[TraceCtx] = [] @@ -679,7 +676,6 @@ def _execute_trace( trc, executors_list=compile_data.executors_list, only_execute_prims=compile_data.only_execute_prims, - use_rematerialization=compile_data.use_rematerialization, ) extrace = extraces[-1] diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index f6e68f0e23..52507be74b 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -77,12 +77,8 @@ def func(a, b, *, c): trace = wrap_return_value_together_with_argments(trace) fw_trace, bw_trace = forward_and_backward_from_trace(trace) - fw_extraces = transform_for_execution( - fw_trace, executors_list=executor.executors_list(), use_rematerialization=False - ) - bw_extraces = transform_for_execution( - bw_trace, executors_list=executor.executors_list(), use_rematerialization=False - ) + fw_extraces = transform_for_execution(fw_trace, executors_list=executor.executors_list()) + bw_extraces = transform_for_execution(bw_trace, executors_list=executor.executors_list()) fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extraces[-1], bw_extraces[-1]) fw = fw_extrace.python_callable()