Skip to content

Commit

Permalink
update bsym representing torch.autograd.Funcion
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 21, 2024
1 parent 60f3ee1 commit 5d328cd
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,21 @@ def grad_transform(*args, **kwargs):
execution_transform=core_of_forward,
grad_transform=grad_transform,
)

bsym_of_func: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
for d_to_update, src in zip((import_ctx, call_ctx, object_ctx), bsym.gather_ctxs()):
d_to_update.update(src)
if import_ctx:
bsym_of_func._import_ctx.update(import_ctx)
if call_ctx:
if bsym_of_func._call_ctx is not None:
bsym_of_func._call_ctx = call_ctx
else:
bsym_of_func._call_ctx.update(call_ctx)
if object_ctx:
bsym_of_func._object_ctx.update(object_ctx)
return forward_result


Expand Down

0 comments on commit 5d328cd

Please sign in to comment.