From 5d328cd30185bc2d196dbb4c1dd3360e50a2c328 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 21 Nov 2024 17:55:57 +0900 Subject: [PATCH] update bsym representing `torch.autograd.Funcion` Signed-off-by: Masaki Kozuki --- thunder/core/jit_ext.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index cf186b1ea..d22abf32d 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -766,6 +766,21 @@ def grad_transform(*args, **kwargs): execution_transform=core_of_forward, grad_transform=grad_transform, ) + + bsym_of_func: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1] + import_ctx, call_ctx, object_ctx = {}, {}, {} + for bsym in trace_of_fwd.bound_symbols: + for d_to_update, src in zip((import_ctx, call_ctx, object_ctx), bsym.gather_ctxs()): + d_to_update.update(src) + if import_ctx: + bsym_of_func._import_ctx.update(import_ctx) + if call_ctx: + if bsym_of_func._call_ctx is not None: + bsym_of_func._call_ctx = call_ctx + else: + bsym_of_func._call_ctx.update(call_ctx) + if object_ctx: + bsym_of_func._object_ctx.update(object_ctx) return forward_result