diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 4ca21bbb2f..cce1958a91 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -385,7 +385,7 @@ def wrapper(*args, **kwargs): def record_source_loc_in_symbol_header(fn): @wraps(fn) def wrapper(*args, **kwargs): - runtimectx: Interpreterruntimectx = get_interpreterruntimectx() + runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx() filename, positions = runtimectx.get_current_user_source_location() ctx: JitCtx = get_jit_ctx() ctx._computation_trace.set_current_source_location(filename, positions) @@ -904,7 +904,7 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable: lookaside = executor_lookaside # the ad hoc executor may be extended during compilation elif (executor_lookaside := ctx.ad_hoc_executor._lookasides.get(fn, None)) is not None: - lookaside = jit_needs_wrap(executor_lookaside) + lookaside = interpreter_needs_wrap(executor_lookaside) elif isinstance(fn, Symbol) or fn in _clang_fn_set: # Performs symbol lookasides # NOTE Symbols "lookaside" to themselves; this just prevents their internals from being jitted @@ -1007,7 +1007,7 @@ def _maybe_update_proxy_name(orig_value: Any, name: str, is_internal: bool | Non } if is_internal is None: - runtimectx: Interpreterruntimectx = get_interpreterruntimectx() + runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx() frame = runtimectx.peek_frame_stack() assert frame is not None # pass is_internal if you call this before the frame is set up is_internal = frame.module in {"thunder.core.interpreter", "thunder.core.jit_ext"}