Skip to content

Commit

Permalink
add last_prologue_traces earlier (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Aug 3, 2024
1 parent 7660cd5 commit 510d8bf
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
34 changes: 15 additions & 19 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,15 +593,22 @@ def get_computation_and_inputs(*args, **kwargs):
executors_list=(pythonex,),
use_del_last_used=False,
)
protrace = prologue_traces[-1]
pro = protrace.python_callable()
prologue_trc = prologue_traces[-1]
pro = prologue_trc.python_callable()

if epilogue_trc is not None:
epilogue = epilogue_trc.python_callable()
else:
epilogue = None

cs.last_prologue_transformation_stop = time.perf_counter_ns()
cs.last_prologue_traces = prologue_traces
cs.last_prologue = pro
cs.last_traces = computation_traces
backward_traces = []
cs.last_backward_traces = backward_traces
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

cs.last_prologue_execution_start = time.perf_counter_ns()
if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON:
Expand All @@ -611,12 +618,6 @@ def get_computation_and_inputs(*args, **kwargs):
pro_to_epi = None
cs.last_prologue_execution_stop = time.perf_counter_ns()

cs.last_traces = computation_traces
backward_traces = []
cs.last_backward_traces = backward_traces
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)

Expand All @@ -633,7 +634,6 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward
extraces = cs.last_traces

if backward_trc is None:
## EPILOGUE and TRANSFORMS should not mix...
Expand Down Expand Up @@ -666,14 +666,14 @@ def get_computation_and_inputs(*args, **kwargs):
executors_list=cd.executors_list,
use_del_last_used=False,
)
computation_traces.append(computation_trc)
computation_trc = extraces[-1]
computation_traces.extend(extraces)
computation_trc = computation_traces[-1]

if cd.use_cudagraphs:
from thunder.executors.cudagraphex import cudagraphex

computation_trc = cudagraphex.fusion_pass(computation_trc)
extraces.append(computation_trc)
computation_traces.append(computation_trc)

if backward_trc is not None:
backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
Expand All @@ -684,7 +684,7 @@ def get_computation_and_inputs(*args, **kwargs):

if not compile_options.get("disable_inplace_copy_check", False):
thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
extraces.append(computation_trc)
computation_traces.append(computation_trc)

for transform in transforms:
# NOTE: `backward_trc` could be None.
Expand All @@ -693,7 +693,7 @@ def get_computation_and_inputs(*args, **kwargs):
)
if new_computation_trc is not computation_trc:
computation_trc = new_computation_trc
extraces.append(computation_trc)
computation_traces.append(computation_trc)
if backward_trc is not None:
new_backward_trc = transform.transform_trace_post_optimization(
backward_trc, executors_list=cd.executors_list
Expand All @@ -715,7 +715,7 @@ def get_computation_and_inputs(*args, **kwargs):
pro,
prologue_traces,
comp,
extraces,
computation_traces,
epilogue,
epilogue_traces,
backward_fn,
Expand All @@ -726,10 +726,6 @@ def get_computation_and_inputs(*args, **kwargs):
if cd.cache_option is not CACHE_OPTIONS.NO_CACHING:
cs.interpreter_cache.append(cache_entry)

cs.last_traces += extraces
cs.last_prologue_traces = [prologue_trc] + prologue_traces
cs.last_prologue = pro

return cache_entry, inps, pro_to_epi

cd.get_computation_and_inputs = get_computation_and_inputs
Expand Down
17 changes: 17 additions & 0 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,3 +1203,20 @@ def fn():
return torch.cuda.get_device_properties(0).major

assert fn() == thunder.jit(fn)()


def test_failing_prologue_in_last_prologue_traces():
    """Check that prologue traces stay inspectable after the prologue raises.

    The jitted call below is expected to raise a ``RuntimeError`` from the
    prologue; afterwards ``thunder.last_prologue_traces`` must still report
    at least one trace for the jitted function.
    """
    i = 0

    def fn():
        # Mutates the enclosing counter on every call.
        # NOTE(review): presumably the prologue's check on the captured value
        # of ``i`` is what fails once it has been bumped -- confirm.
        nonlocal i
        i += 1
        return i

    jitted = thunder.jit(fn)

    # This call is known to fail in the prologue.
    expected_message = "Expected 1 to be equal to and have the type of 0"
    with pytest.raises(RuntimeError, match=expected_message):
        jitted()

    # The prologue traces must have been recorded despite the failure.
    recorded_prologue_traces = thunder.last_prologue_traces(jitted)
    assert len(recorded_prologue_traces) > 0

0 comments on commit 510d8bf

Please sign in to comment.