Skip to content

Commit

Permalink
use push|pop_scope
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 13, 2024
1 parent f349308 commit 8791f81
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,9 +792,7 @@ def _generate_random_str_id() -> str:
# non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
length_of_tensor_args = sum(args_tensor_mask)
new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]
old_scope = jit_ctx.computation_trace.scopes
fwd_bsyms = []
jit_ctx.computation_trace.scopes = [fwd_bsyms]
jit_ctx.computation_trace.push_scope([])

fwd_result = _interpret_call(fwd, *new_fwd_args)
if fwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
Expand All @@ -804,6 +802,7 @@ def _generate_random_str_id() -> str:

unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)[1:]

fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
producer_map = utils.producers(fwd_bsyms)
tensor_to_prod_bsym: dict[Variable, BoundSymbol] = {}
for p in tree_flatten((output, saved_values))[0]:
Expand Down Expand Up @@ -838,7 +837,7 @@ def _generate_random_str_id() -> str:
_object_ctx=fwd_bsyms[0]._object_ctx,
_executor=fwd_bsyms[0]._executor,
)
old_scope[-1].append(bsym_of_custom_autograd_func)
jit_ctx.computation_trace.scopes[-1].append(bsym_of_custom_autograd_func)

# Define augmented fwd rule and backward rule on the fly.
augmented_fwd_trace = TraceCtx()
Expand All @@ -863,10 +862,8 @@ def augmented_fwd_rule(*args):

augmented_forward_impls[sym.id] = augmented_fwd_rule

bwd_bsyms = []
jit_ctx.computation_trace.scopes = [bwd_bsyms]
jit_ctx.computation_trace.push_scope([])
bwd_trace = TraceCtx()
bwd_trace.bound_symbols = bwd_bsyms

grads = sequencify(tree_map(lambda t: TensorProxy(like=t), output))
bwd_args = (wrap_const(None),)
Expand All @@ -876,6 +873,7 @@ def augmented_fwd_rule(*args):
if bwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return bwd_result
unwrapped_bwd_result = unwrap(bwd_result)
bwd_trace.bound_symbols = jit_ctx.computation_trace.pop_scope()
bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=()))

bwd_si = SigInfo(f"bwd_{si.name}")
Expand All @@ -884,7 +882,6 @@ def augmented_fwd_rule(*args):
bwd_trace._siginfo = bwd_si
backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False)

jit_ctx.computation_trace.scopes = old_scope
return wrapped_output


Expand Down

0 comments on commit 8791f81

Please sign in to comment.