diff --git a/thunder/__init__.py b/thunder/__init__.py index d0e8fa2aca..438958b425 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -593,8 +593,8 @@ def get_computation_and_inputs(*args, **kwargs): executors_list=(pythonex,), use_del_last_used=False, ) - protrace = prologue_traces[-1] - pro = protrace.python_callable() + prologue_trc = prologue_traces[-1] + pro = prologue_trc.python_callable() if epilogue_trc is not None: epilogue = epilogue_trc.python_callable() @@ -602,6 +602,13 @@ def get_computation_and_inputs(*args, **kwargs): epilogue = None cs.last_prologue_transformation_stop = time.perf_counter_ns() + cs.last_prologue_traces = prologue_traces + cs.last_prologue = pro + cs.last_traces = computation_traces + backward_traces = [] + cs.last_backward_traces = backward_traces + cs.last_interpreter_log = last_interpreter_log + cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction)) cs.last_prologue_execution_start = time.perf_counter_ns() if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON: @@ -611,12 +618,6 @@ def get_computation_and_inputs(*args, **kwargs): pro_to_epi = None cs.last_prologue_execution_stop = time.perf_counter_ns() - cs.last_traces = computation_traces - backward_traces = [] - cs.last_backward_traces = backward_traces - cs.last_interpreter_log = last_interpreter_log - cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction)) - computation_trc = dce(computation_trc) computation_traces.append(computation_trc) @@ -633,7 +634,6 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward - extraces = cs.last_traces if backward_trc is None: ## EPILOGUE and TRANSFORMS should not mix... @@ -666,14 +666,14 @@ def get_computation_and_inputs(*args, **kwargs): executors_list=cd.executors_list, use_del_last_used=False, ) - computation_traces.append(computation_trc) - computation_trc = extraces[-1] + computation_traces.extend(extraces) + computation_trc = computation_traces[-1] if cd.use_cudagraphs: from thunder.executors.cudagraphex import cudagraphex computation_trc = cudagraphex.fusion_pass(computation_trc) - extraces.append(computation_trc) + computation_traces.append(computation_trc) if backward_trc is not None: backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0])) @@ -684,7 +684,7 @@ def get_computation_and_inputs(*args, **kwargs): if not compile_options.get("disable_inplace_copy_check", False): thunder.core.transform_common._inplace_copy_sanity_check(computation_trc) - extraces.append(computation_trc) + computation_traces.append(computation_trc) for transform in transforms: # NOTE: `backward_trc` could be None. @@ -693,7 +693,7 @@ def get_computation_and_inputs(*args, **kwargs): ) if new_computation_trc is not computation_trc: computation_trc = new_computation_trc - extraces.append(computation_trc) + computation_traces.append(computation_trc) if backward_trc is not None: new_backward_trc = transform.transform_trace_post_optimization( backward_trc, executors_list=cd.executors_list @@ -715,7 +715,7 @@ def get_computation_and_inputs(*args, **kwargs): pro, prologue_traces, comp, - extraces, + computation_traces, epilogue, epilogue_traces, backward_fn, @@ -726,10 +726,6 @@ def get_computation_and_inputs(*args, **kwargs): if cd.cache_option is not CACHE_OPTIONS.NO_CACHING: cs.interpreter_cache.append(cache_entry) - cs.last_traces += extraces - cs.last_prologue_traces = [prologue_trc] + prologue_traces - cs.last_prologue = pro - return cache_entry, inps, pro_to_epi cd.get_computation_and_inputs = get_computation_and_inputs diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index a54dd6cb92..f336775c5d 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1203,3 +1203,20 @@ def fn(): return torch.cuda.get_device_properties(0).major assert fn() == thunder.jit(fn)() + + +def test_failing_prologue_in_last_prologue_traces(): + # we know that this will fail in the prologue + i = 0 + + def fn(): + nonlocal i + i += 1 + return i + + jfn = thunder.jit(fn) + with pytest.raises(RuntimeError, match="Expected 1 to be equal to and have the type of 0"): + jfn() + + # make sure that we have prologue traces in the last_prologue_traces + assert len(thunder.last_prologue_traces(jfn)) > 0