diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py
index b000d4bf61..ef2932ac3d 100644
--- a/thunder/executors/torch_compile.py
+++ b/thunder/executors/torch_compile.py
@@ -13,11 +13,7 @@
 from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
 from thunder.core.transform_common import dce
 from thunder.core.pytree import tree_flatten
-from thunder.executors.passes import (
-    update_fusion_call_ctx,
-    _transform_for_operator_executor_execution,
-    transform_for_execution,
-)
+from thunder.executors.passes import update_fusion_call_ctx
 from thunder.executors.utils import Region
 from thunder.extend import FusionExecutor, register_executor, ImplInfo
 from thunder.core.compile_data import get_compile_option
@@ -46,7 +42,15 @@ def _to_torch(*args, **kwargs) -> Any:
                 return impl_info.execution_transform(*args, **kwargs)
 
         if torch_op is None:
-            torch_op = torchex.opmap[bsym.sym.name]
+            torch_op = torchex.opmap.get(bsym.sym.name)
+
+            # This should be very rare, but type_as is one such op;
+            # ideally we would also handle more subsymbols here.
+            if torch_op is None and len(bsym.subsymbols) == 1:
+                torch_op = torchex.opmap.get(bsym.subsymbols[0].sym.name)
+
+            if torch_op is None:
+                raise RuntimeError(f"op not found for {bsym.sym.name}")
 
         return torch_op(*args, **kwargs)
 
@@ -59,31 +63,36 @@ def make_compiled(
     from thunder import trace
     from thunder.core.transforms import eval_trace
     from thunder.executors.torchex import no_autocast
-    from thunder.executors.torchex import ex as torchex
-    from thunder.executors.pythonex import ex as pythonex
-    from thunder.core.codeutils import SigInfo
 
     # Here we construct a trace that will be used to compile the function
-    # TODO: maybe we should have a utility that does this properly
     region_trace = TraceCtx(None)
     region_trace.bound_symbols = list(bsyms)
     region_trace.args = sorted_unique_inputs
     region_trace.kwargs = {}
     region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=()))
-    for a in region_trace.args:
-        region_trace.add_name(a.name)
-    for bsym in region_trace.bound_symbols:
-        for o in bsym.flat_outs:
-            if o is not None:  # TODO: investigate
-                region_trace.add_name(o.name)
-
-    # maybe make this the default if no sig info is present?
-    region_trace._siginfo = SigInfo("to_be_compiled")
-    region_trace._siginfo.args = [(a.name, None) for a in region_trace.args]
-
-    torchex_trace = transform_for_execution(region_trace, executors_list=(torchex,))
-    trace_callable = torchex_trace.python_callable(include_decorators=False)
+
+    def torch_interpreted_func(*args):
+        return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)
+
+    # Here instead of using thunder.trace we could use torch_trace =
+    # passes._transform_for_operator_executor_execution(region_trace, [torchex])
+    # but then we would need to handle unpacking of the args explicitly. For
+    # example with:
+    #     try:
+    #         token = set_tracectx(region_trace)
+    #         col = CollectionProxy(region_trace.args, name="args")
+    #         _ = prims.unpack_sequence(col, len(region_trace.args))
+    #     finally:
+    #         reset_tracectx(token)
+    #     region_trace.bound_symbols.extend(bsyms)
+    # But there are some issues with the
+    # _transform_for_operator_executor_execution implementation that need to
+    # be fixed first. One issue is that it doesn't maintain the SSA form of
+    # the trace, which is needed for all the passes to work correctly.
+ # TODO: issue "Try using _transform_for_operator_executor_execution for + # torch.compile executor" + torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs) + trace_callable = torch_trace.python_callable(include_decorators=False) torch_compile_fullgraph: None | bool = get_compile_option( "torch_compile_fullgraph", "Whether to enable `fullgraph` from `torch.compile`. Defaults to `True`." ) diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index 3f7b331387..48e37f8a89 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -9,7 +9,6 @@ from thunder.tests.bf16 import device_supports_bf16 from thunder.tests.litgpt_model import GPT, Config from thunder.tests.framework import requiresCUDA -from torch.testing import assert_close def test_supported_ops_are_in_pytorch_executor(): @@ -72,13 +71,3 @@ def test_torch_compile_cat_rope_single_fusion(): backward_execution_trace = thunder.last_backward_traces(jfn)[-1] assert len(get_fusions(backward_execution_trace)) == 1 assert len(backward_execution_trace.bound_symbols) == 14 - - -@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported") -def test_transform_for_execution_for_callable(): - def fn(a): - return a.type("torch.DoubleTensor") - - a = torch.randn(3) - jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,)) - assert_close(jfn(a), fn(a))