Skip to content

Commit

Permalink
Remove unused use_rematerialization option from transform_for_executi…
Browse files Browse the repository at this point in the history
…on and jit (#1320)
  • Loading branch information
wujingyue authored Oct 17, 2024
1 parent 40904c7 commit 2253a7b
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 11 deletions.
1 change: 0 additions & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ def jit(
sharp_edges=sharp_edges,
using_jit=True,
disable_torch_autograd_support=disable_torch_autograd,
use_rematerialization=False,
only_execute_prims=False,
disable_preprocessing=True,
compile_options=compile_options,
Expand Down
4 changes: 0 additions & 4 deletions thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def __init__(
only_execute_prims: bool = False,
disable_preprocessing: bool = False,
disable_torch_autograd_support: bool = False,
use_rematerialization: bool = False,
debug_log: None | StringIO = None,
compile_options: dict[str, Any] = {},
get_computation_and_inputs: Callable | None = None,
Expand Down Expand Up @@ -254,7 +253,6 @@ def __init__(
self.fn = fn
self.only_execute_prims = only_execute_prims
self.disable_preprocessing = disable_preprocessing
self.use_rematerialization = use_rematerialization
self.disable_torch_autograd_support = disable_torch_autograd_support
self.debug_log = debug_log

Expand Down Expand Up @@ -636,7 +634,6 @@ def transform_for_execution(
executors_list: Sequence[Executor],
*,
only_execute_prims=False,
use_rematerialization=True,
use_del_last_used=True,
) -> list[TraceCtx]:
traces: list[TraceCtx] = []
Expand Down Expand Up @@ -679,7 +676,6 @@ def _execute_trace(
trc,
executors_list=compile_data.executors_list,
only_execute_prims=compile_data.only_execute_prims,
use_rematerialization=compile_data.use_rematerialization,
)
extrace = extraces[-1]

Expand Down
8 changes: 2 additions & 6 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,8 @@ def func(a, b, *, c):
trace = wrap_return_value_together_with_argments(trace)
fw_trace, bw_trace = forward_and_backward_from_trace(trace)

fw_extraces = transform_for_execution(
fw_trace, executors_list=executor.executors_list(), use_rematerialization=False
)
bw_extraces = transform_for_execution(
bw_trace, executors_list=executor.executors_list(), use_rematerialization=False
)
fw_extraces = transform_for_execution(fw_trace, executors_list=executor.executors_list())
bw_extraces = transform_for_execution(bw_trace, executors_list=executor.executors_list())
fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extraces[-1], bw_extraces[-1])

fw = fw_extrace.python_callable()
Expand Down

0 comments on commit 2253a7b

Please sign in to comment.