From 019557e1cbd2944a4d8d719954ee8a7c295c539b Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 9 Aug 2024 14:00:46 +0200 Subject: [PATCH] cleanup some minor pickle bits (#945) --- thunder/core/prims.py | 18 +++----- thunder/core/rematerialization.py | 2 +- thunder/core/symbol.py | 68 ++++++++++++---------------- thunder/distributed/utils.py | 5 +- thunder/tests/test_core.py | 13 ++++-- thunder/tests/test_examine_memory.py | 2 +- 6 files changed, 50 insertions(+), 58 deletions(-) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 3dc3d6eaf4..a4c897c57b 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -297,7 +297,6 @@ def make_prim( method_name: None | str = None, _bind_postprocess: None | Callable = None, _print_as_impl: bool = False, - python_name: str | None = None, ): sym = Symbol( name=name, @@ -309,7 +308,6 @@ def make_prim( python_impl=python_impl, _bind_postprocess=_bind_postprocess, _print_as_impl=_print_as_impl, - _python_name=python_name, ) if method_name is not None: @@ -485,10 +483,9 @@ def _check_tensor_shape_and_metadata_meta( check_tensor_shape_and_metadata = make_prim( PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA, - "check_tensor_metadata", + "check_tensor_shape_and_metadata", meta=_check_tensor_shape_and_metadata_meta, tags=(OpTags.DONT_DCE,), - python_name="check_tensor_shape_and_metadata", ) @@ -1188,7 +1185,7 @@ def pack_buffer_impl(o: Any, key: Any, v: Any) -> None: pack_buffer = make_prim( PrimIDs.PACK_BUFFER, - "unpack_buffer", + "pack_buffer", meta=pack_buffer_meta, python_printer=pack_buffer_printer, python_impl=pack_buffer_impl, @@ -1230,7 +1227,7 @@ def pack_setitem_impl(o: Any, key: Any, v: Any) -> None: pack_setitem = make_prim( PrimIDs.PACK_SETITEM, - "unpack_setitem", + "pack_setitem", meta=pack_setitem_meta, python_printer=pack_setitem_printer, python_impl=pack_setitem_impl, @@ -1560,12 +1557,11 @@ def python_print_printer( python_print = make_prim( PrimIDs.PRINT, - "print", + "python_print", meta=_print_meta, python_printer=python_print_printer, python_impl=print, tags=(OpTags.DONT_DCE,), - python_name="python_print", ) @@ -1630,11 +1626,10 @@ def _del_impl(x: Any, /) -> None: python_del = make_prim( PrimIDs.DEL, - "del", + "python_del", meta=_del_meta, python_printer=del_printer, python_impl=_del_impl, - python_name="python_del", ) @@ -1667,12 +1662,11 @@ def _return_impl(*args) -> Any: python_return = make_prim( PrimIDs.RETURN, - "return", + "python_return", meta=_return_meta, python_printer=return_printer, python_impl=_return_impl, tags=(OpTags.DONT_DCE,), - python_name="python_return", ) # diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 8ee9d67bfb..1aa01d529e 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -54,7 +54,7 @@ def is_rematerializable(out: ProxyInterface): # to see if the output is used by other consumers. global_consumers = proxy_to_consumers.get(out, tuple()) global_consumers = tuple( - x for x in global_consumers if x.sym.name != "del" and x not in chain((consumer,), next_consumers) + x for x in global_consumers if x.sym is not prims.python_del and x not in chain((consumer,), next_consumers) ) # If the output is used by other global consumers, it's not rematerializable. diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index fe6a872f78..4d18d30219 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -138,7 +138,6 @@ class Symbol: executor: None | Any = None python_impl: None | Callable = None _print_as_impl: bool = False # If not None, w - _python_name: str | None = None # An optional postprocessing function to modify the bound symbol resulting from bind() _bind_postprocess: None | Callable = None @@ -200,54 +199,45 @@ def module(self) -> None | ModuleType: result = inspect.getmodule(fn_) return result - @classmethod - def lookup_from_module(cls, name: str, executor: Any, module: ModuleType) -> Symbol: # For unpickling - if module not in sys.modules: - raise RuntimeError(f"Cannot find module {module} for symbol {name}.") - + @staticmethod + def lookup_symbol(name: str, executor: Any, module: ModuleType) -> Symbol: # for unpickling if executor is None: + if module not in sys.modules: + raise RuntimeError(f"Cannot find module {module} for symbol {name}.") not_found = object() sym = getattr(sys.modules[module], name, not_found) if sym is not_found: raise RuntimeError(f"Could not find symbol {name} in module {module}.") - assert isinstance(sym, Symbol), (name, module, type(sym), sym) - return sym + assert isinstance(sym, Symbol), f"lookup {module}.{name} gave object of type {type(sym)} instead of Symbol" else: - # Try to find the executor in all_executors - import thunder.extend - - executors = thunder.extend.get_all_executors() - - for ex in executors: - implmap = ex.implmap.values() - - for key, info in implmap: - assert isinstance(key.id, str) - if key.id == name: - if ( - impl.symbol is not None - and module is not None - and impl.module is not None - and module != impl.module - ): - continue - return lookup_from_module(name, ex, module) - - raise ValueError(f"Could not find an executor for symbol {name} from module {module.__qualname__}.") - - def __reduce__(self): # For pickling - if self.module is None: - raise ValueError("Cannot serialize a symbol without a module.") - - if hasattr(self, "_python_name") and not self._python_name is None: - name = self._python_name + import thunder + + ex = thunder.get_executor(executor) + sym = ex.opmap.get(name) + + if sym is None: + raise RuntimeError(f"Could not find symbol {name} in executor {executor}.") + assert isinstance( + sym, Symbol + ), f"lookup {name} in executor {executor} gave object of type {type(sym)} instead of Symbol" + + return sym + + def __reduce__(self): # for pickling + import thunder + + if self.module is None and self.executor is None: + raise ValueError("Cannot serialize a symbol without a module and executor.") + + if self.executor is None: + assert getattr(sys.modules[self.module.__name__], self.name, None) is self else: - name = self.name + assert thunder.get_executor(self.executor.name).opmap.get(self.name) is self return ( - Symbol.lookup_from_module, + Symbol.lookup_symbol, ( - name, + self.name, None if self.executor is None else self.executor.name, None if self.module is None else self.module.__name__, ), diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py index b0d2b97818..fbd5f8d0ad 100644 --- a/thunder/distributed/utils.py +++ b/thunder/distributed/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +import thunder from thunder.core.trace import from_trace from thunder.core.transforms import bsym_list_to_dag, Node, toposort_bsym_dag, TOPOSORT_ORDER from thunder.core.utils import check @@ -105,7 +106,7 @@ def key(node: Node) -> int: # TODO: This pass doesn't behave correctly if del nodes are present in the trace check( - not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols), + not any(bsym.sym is thunder.core.prims.python_del for bsym in execution_trace.bound_symbols), lambda: "Cannot sort execution trace with del nodes", ) new_execution_trace.bound_symbols = toposort_bsym_dag( @@ -165,7 +166,7 @@ def key(node: Node) -> int: # TODO: This pass doesn't behave correctly if del nodes are present in the trace check( - not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols), + not any(bsym.sym is thunder.core.prims.python_del for bsym in execution_trace.bound_symbols), lambda: "Cannot sort execution trace with del nodes", ) new_execution_trace.bound_symbols = toposort_bsym_dag( diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index bea42ee87c..25c58935b8 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2785,16 +2785,23 @@ def foo2(x): def test_serialize_trace(): import dill as pickle - def fn(a, b): - return a + b + def fn(a, b, l): + res = a + b + for t in l: + res = res + t + return res tm = thunder.jit(fn) a, b = torch.randn(2, 5, device=("cuda" if torch.cuda.is_available() else "cpu")) - tm(a, b) + tm(a, b, [a, b]) trace = thunder.last_traces(tm)[0] assert str(pickle.loads(pickle.dumps(trace))) == str(trace) + prologue_trace = thunder.last_prologue_traces(tm)[0] + + assert str(pickle.loads(pickle.dumps(prologue_trace))) == str(prologue_trace) + @pytest.mark.parametrize("requires_grad", (True, False)) def test_dataclass_output(requires_grad): diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index 724975200a..b2578666e1 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -29,7 +29,7 @@ def runtime_allocated_memory(dev): def get_return_memory(bsym): - assert bsym.sym.name == "return" + assert bsym.sym is thunder.core.prims.python_return return_tensors_name = set() res = 0 for x in bsym.flat_proxy_args: