revert using transform to execution, add ad-hoc fix for type_as (#1069)
t-vi authored Aug 29, 2024
1 parent a89aa48 commit 339a782
Showing 2 changed files with 32 additions and 34 deletions.
55 changes: 32 additions & 23 deletions thunder/executors/torch_compile.py
@@ -13,11 +13,7 @@
from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
from thunder.core.transform_common import dce
from thunder.core.pytree import tree_flatten
from thunder.executors.passes import (
update_fusion_call_ctx,
_transform_for_operator_executor_execution,
transform_for_execution,
)
from thunder.executors.passes import update_fusion_call_ctx
from thunder.executors.utils import Region
from thunder.extend import FusionExecutor, register_executor, ImplInfo
from thunder.core.compile_data import get_compile_option
@@ -46,7 +42,15 @@ def _to_torch(*args, **kwargs) -> Any:
return impl_info.execution_transform(*args, **kwargs)

if torch_op is None:
torch_op = torchex.opmap[bsym.sym.name]
torch_op = torchex.opmap.get(bsym.sym.name)

# this should be really rare, but type_as has this;
# ideally we would also handle more subsymbols here
if torch_op is None and len(bsym.subsymbols) == 1:
torch_op = torchex.opmap.get(bsym.subsymbols[0].sym.name)

if torch_op is None:
raise RuntimeError(f"op not found for {bsym.sym.name}")

return torch_op(*args, **kwargs)
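
Taken out of the diff, the fallback above boils down to the following lookup pattern (a minimal, self-contained sketch; `lookup_torch_op` and its standalone form are illustrative, not Thunder's actual API):

```python
# Sketch of the symbol -> torch-op lookup with the single-subsymbol fallback
# added in this commit. `opmap` maps symbol names to torch-executor
# implementations; a symbol like `type_as` has no direct entry but decomposes
# into exactly one subsymbol whose name does appear in the map.
def lookup_torch_op(bsym, opmap):
    torch_op = opmap.get(bsym.sym.name)
    if torch_op is None and len(bsym.subsymbols) == 1:
        # e.g. type_as: fall back to the lone subsymbol's implementation
        torch_op = opmap.get(bsym.subsymbols[0].sym.name)
    if torch_op is None:
        raise RuntimeError(f"op not found for {bsym.sym.name}")
    return torch_op
```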

@@ -59,31 +63,36 @@ def make_compiled(
from thunder import trace
from thunder.core.transforms import eval_trace
from thunder.executors.torchex import no_autocast
from thunder.executors.torchex import ex as torchex
from thunder.executors.pythonex import ex as pythonex
from thunder.core.codeutils import SigInfo

# Here we construct a trace that will be used to compile the function
# TODO: maybe we should have a utility that does this properly
region_trace = TraceCtx(None)
region_trace.bound_symbols = list(bsyms)
region_trace.args = sorted_unique_inputs
region_trace.kwargs = {}
region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=()))
for a in region_trace.args:
region_trace.add_name(a.name)
for bsym in region_trace.bound_symbols:
for o in bsym.flat_outs:
if o is not None: # TODO: investigate
region_trace.add_name(o.name)

# maybe make this the default if no sig info is present?
region_trace._siginfo = SigInfo("to_be_compiled")
region_trace._siginfo.args = [(a.name, None) for a in region_trace.args]

torchex_trace = transform_for_execution(region_trace, executors_list=(torchex,))
trace_callable = torchex_trace.python_callable(include_decorators=False)

def torch_interpreted_func(*args):
return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)

# Here instead of using thunder.trace we could use torch_trace =
# passes._transform_for_operator_executor_execution(region_trace, [torchex])
# but then we would need to handle unpacking of the args explicitly. For
# example with:
# try:
# token = set_tracectx(region_trace)
# col = CollectionProxy(region_trace.args, name="args")
# _ = prims.unpack_sequence(col, len(region_trace.args))
# finally:
# reset_tracectx(token)
# region_trace.bound_symbols.extend(bsyms)
# But there are some issues with the
# _transform_for_operator_executor_execution implementation that need to be
# fixed first. One issue is that it doesn't maintain the SSA form of the
# trace, which is needed for all the passes to work correctly.
# TODO: issue "Try using _transform_for_operator_executor_execution for
# torch.compile executor"
torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs)
trace_callable = torch_trace.python_callable(include_decorators=False)
torch_compile_fullgraph: None | bool = get_compile_option(
"torch_compile_fullgraph", "Whether to enable `fullgraph` from `torch.compile`. Defaults to `True`."
)
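For context, a minimal usage sketch of this executor, modeled on the test removed below (the `thunder.jit` call and executor name are taken from that test; the standalone layout is illustrative):

```python
import torch
import thunder
from thunder.executors.torch_compile import torch_compile_ex

def fn(a):
    # the kind of dtype-conversion op the removed test exercised
    return a.type("torch.DoubleTensor")

a = torch.randn(3)
jfn = thunder.jit(fn, executors=(torch_compile_ex,))
torch.testing.assert_close(jfn(a), fn(a))
```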
11 changes: 0 additions & 11 deletions thunder/tests/test_torch_compile_executor.py
@@ -9,7 +9,6 @@
from thunder.tests.bf16 import device_supports_bf16
from thunder.tests.litgpt_model import GPT, Config
from thunder.tests.framework import requiresCUDA
from torch.testing import assert_close


def test_supported_ops_are_in_pytorch_executor():
@@ -72,13 +71,3 @@ def test_torch_compile_cat_rope_single_fusion():
backward_execution_trace = thunder.last_backward_traces(jfn)[-1]
assert len(get_fusions(backward_execution_trace)) == 1
assert len(backward_execution_trace.bound_symbols) == 14


@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
def test_transform_for_execution_for_callable():
def fn(a):
return a.type("torch.DoubleTensor")

a = torch.randn(3)
jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,))
assert_close(jfn(a), fn(a))
