Skip to content

Commit

Permalink
cleanup some minor pickle bits (#945)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Aug 9, 2024
1 parent 9f6e5b1 commit 019557e
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 58 deletions.
18 changes: 6 additions & 12 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ def make_prim(
method_name: None | str = None,
_bind_postprocess: None | Callable = None,
_print_as_impl: bool = False,
python_name: str | None = None,
):
sym = Symbol(
name=name,
Expand All @@ -309,7 +308,6 @@ def make_prim(
python_impl=python_impl,
_bind_postprocess=_bind_postprocess,
_print_as_impl=_print_as_impl,
_python_name=python_name,
)

if method_name is not None:
Expand Down Expand Up @@ -485,10 +483,9 @@ def _check_tensor_shape_and_metadata_meta(

check_tensor_shape_and_metadata = make_prim(
PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA,
"check_tensor_metadata",
"check_tensor_shape_and_metadata",
meta=_check_tensor_shape_and_metadata_meta,
tags=(OpTags.DONT_DCE,),
python_name="check_tensor_shape_and_metadata",
)


Expand Down Expand Up @@ -1188,7 +1185,7 @@ def pack_buffer_impl(o: Any, key: Any, v: Any) -> None:

pack_buffer = make_prim(
PrimIDs.PACK_BUFFER,
"unpack_buffer",
"pack_buffer",
meta=pack_buffer_meta,
python_printer=pack_buffer_printer,
python_impl=pack_buffer_impl,
Expand Down Expand Up @@ -1230,7 +1227,7 @@ def pack_setitem_impl(o: Any, key: Any, v: Any) -> None:

pack_setitem = make_prim(
PrimIDs.PACK_SETITEM,
"unpack_setitem",
"pack_setitem",
meta=pack_setitem_meta,
python_printer=pack_setitem_printer,
python_impl=pack_setitem_impl,
Expand Down Expand Up @@ -1560,12 +1557,11 @@ def python_print_printer(

python_print = make_prim(
PrimIDs.PRINT,
"print",
"python_print",
meta=_print_meta,
python_printer=python_print_printer,
python_impl=print,
tags=(OpTags.DONT_DCE,),
python_name="python_print",
)


Expand Down Expand Up @@ -1630,11 +1626,10 @@ def _del_impl(x: Any, /) -> None:

python_del = make_prim(
PrimIDs.DEL,
"del",
"python_del",
meta=_del_meta,
python_printer=del_printer,
python_impl=_del_impl,
python_name="python_del",
)


Expand Down Expand Up @@ -1667,12 +1662,11 @@ def _return_impl(*args) -> Any:

python_return = make_prim(
PrimIDs.RETURN,
"return",
"python_return",
meta=_return_meta,
python_printer=return_printer,
python_impl=_return_impl,
tags=(OpTags.DONT_DCE,),
python_name="python_return",
)

#
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def is_rematerializable(out: ProxyInterface):
# to see if the output is used by other consumers.
global_consumers = proxy_to_consumers.get(out, tuple())
global_consumers = tuple(
x for x in global_consumers if x.sym.name != "del" and x not in chain((consumer,), next_consumers)
x for x in global_consumers if x.sym is not prims.python_del and x not in chain((consumer,), next_consumers)
)

# If the output is used by other global consumers, it's not rematerializable.
Expand Down
68 changes: 29 additions & 39 deletions thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ class Symbol:
executor: None | Any = None
python_impl: None | Callable = None
_print_as_impl: bool = False # If not None, w
_python_name: str | None = None

# An optional postprocessing function to modify the bound symbol resulting from bind()
_bind_postprocess: None | Callable = None
Expand Down Expand Up @@ -200,54 +199,45 @@ def module(self) -> None | ModuleType:
result = inspect.getmodule(fn_)
return result

@classmethod
def lookup_from_module(cls, name: str, executor: Any, module: ModuleType) -> Symbol: # For unpickling
if module not in sys.modules:
raise RuntimeError(f"Cannot find module {module} for symbol {name}.")

@staticmethod
def lookup_symbol(name: str, executor: Any, module: ModuleType) -> Symbol: # for unpickling
if executor is None:
if module not in sys.modules:
raise RuntimeError(f"Cannot find module {module} for symbol {name}.")
not_found = object()
sym = getattr(sys.modules[module], name, not_found)
if sym is not_found:
raise RuntimeError(f"Could not find symbol {name} in module {module}.")
assert isinstance(sym, Symbol), (name, module, type(sym), sym)
return sym
assert isinstance(sym, Symbol), f"lookup {module}.{name} gave object of type {type(sym)} instead of Symbol"
else:
# Try to find the executor in all_executors
import thunder.extend

executors = thunder.extend.get_all_executors()

for ex in executors:
implmap = ex.implmap.values()

for key, info in implmap:
assert isinstance(key.id, str)
if key.id == name:
if (
impl.symbol is not None
and module is not None
and impl.module is not None
and module != impl.module
):
continue
return lookup_from_module(name, ex, module)

raise ValueError(f"Could not find an executor for symbol {name} from module {module.__qualname__}.")

def __reduce__(self): # For pickling
if self.module is None:
raise ValueError("Cannot serialize a symbol without a module.")

if hasattr(self, "_python_name") and not self._python_name is None:
name = self._python_name
import thunder

ex = thunder.get_executor(executor)
sym = ex.opmap.get(name)

if sym is None:
raise RuntimeError(f"Could not find symbol {name} in executor {executor}.")
assert isinstance(
sym, Symbol
), f"lookup {name} in executor {executor} gave object of type {type(sym)} instead of Symbol"

return sym

def __reduce__(self): # for pickling
import thunder

if self.module is None and self.executor is None:
raise ValueError("Cannot serialize a symbol without a module and executor.")

if self.executor is None:
assert getattr(sys.modules[self.module.__name__], self.name, None) is self
else:
name = self.name
assert thunder.get_executor(self.executor.name).opmap.get(self.name) is self

return (
Symbol.lookup_from_module,
Symbol.lookup_symbol,
(
name,
self.name,
None if self.executor is None else self.executor.name,
None if self.module is None else self.module.__name__,
),
Expand Down
5 changes: 3 additions & 2 deletions thunder/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING

import thunder
from thunder.core.trace import from_trace
from thunder.core.transforms import bsym_list_to_dag, Node, toposort_bsym_dag, TOPOSORT_ORDER
from thunder.core.utils import check
Expand Down Expand Up @@ -105,7 +106,7 @@ def key(node: Node) -> int:

# TODO: This pass doesn't behave correctly if del nodes are present in the trace
check(
not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols),
not any(bsym.sym is thunder.core.prims.python_del for bsym in execution_trace.bound_symbols),
lambda: "Cannot sort execution trace with del nodes",
)
new_execution_trace.bound_symbols = toposort_bsym_dag(
Expand Down Expand Up @@ -165,7 +166,7 @@ def key(node: Node) -> int:

# TODO: This pass doesn't behave correctly if del nodes are present in the trace
check(
not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols),
not any(bsym.sym is thunder.core.prims.python_del for bsym in execution_trace.bound_symbols),
lambda: "Cannot sort execution trace with del nodes",
)
new_execution_trace.bound_symbols = toposort_bsym_dag(
Expand Down
13 changes: 10 additions & 3 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2785,16 +2785,23 @@ def foo2(x):
def test_serialize_trace():
import dill as pickle

def fn(a, b):
return a + b
def fn(a, b, l):
res = a + b
for t in l:
res = res + t
return res

tm = thunder.jit(fn)
a, b = torch.randn(2, 5, device=("cuda" if torch.cuda.is_available() else "cpu"))
tm(a, b)
tm(a, b, [a, b])
trace = thunder.last_traces(tm)[0]

assert str(pickle.loads(pickle.dumps(trace))) == str(trace)

prologue_trace = thunder.last_prologue_traces(tm)[0]

assert str(pickle.loads(pickle.dumps(prologue_trace))) == str(prologue_trace)


@pytest.mark.parametrize("requires_grad", (True, False))
def test_dataclass_output(requires_grad):
Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/test_examine_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def runtime_allocated_memory(dev):


def get_return_memory(bsym):
assert bsym.sym.name == "return"
assert bsym.sym is thunder.core.prims.python_return
return_tensors_name = set()
res = 0
for x in bsym.flat_proxy_args:
Expand Down

0 comments on commit 019557e

Please sign in to comment.