From 6b2db39cda1850be96fd54834a1b926ba6859752 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Tue, 26 Nov 2024 14:12:02 -0500 Subject: [PATCH] Injecting .saved() values into their Frames locals so you dont need to do .value ever --- src/nnsight/tracing/backends/base.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/nnsight/tracing/backends/base.py b/src/nnsight/tracing/backends/base.py index cb5e48cd..1d40311e 100755 --- a/src/nnsight/tracing/backends/base.py +++ b/src/nnsight/tracing/backends/base.py @@ -1,7 +1,9 @@ -from ..graph import Graph -from ..protocols import StopProtocol +import inspect import sys + from ...util import NNsightError +from ..graph import Graph, Proxy +from ..protocols import StopProtocol class Backend: @@ -13,12 +15,21 @@ def __call__(self, graph: Graph) -> None: class ExecutionBackend(Backend): + def __init__(self, injection: bool = True) -> None: + self.injection = injection + def __call__(self, graph: Graph) -> None: try: graph.nodes[-1].execute() + if self.injection: + frame = inspect.currentframe().f_back.f_back.f_back.f_back + for key, value in frame.f_locals.items(): + if isinstance(value, Proxy) and value.node.done: + frame.f_locals[key] = value.value + except StopProtocol.StopException: pass @@ -26,7 +37,9 @@ def __call__(self, graph: Graph) -> None: except NNsightError as e: if graph.debug: print(f"\n{e.traceback_content}") - print("During handling of the above exception, another exception occurred:\n") + print( + "During handling of the above exception, another exception occurred:\n" + ) print(f"{graph.nodes[e.node_id].meta_data['traceback']}") sys.tracebacklimit = 0 raise e from None