From aa6ef9cdfc80505665aebd0128adc8470d97bf97 Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx@users.noreply.github.com>
Date: Tue, 7 Jan 2025 10:45:01 +0100
Subject: [PATCH 1/2] add `defer_within_autodiff` to EnzymeInterpreter in order
 for `within_autodiff` to no return true during Reactant compilation. When
 this flag is true, `interp.handler` is responsible for handling
 within_autodiff, or to toggle defer_within_autodiff to false somewhere down
 the call chain.

---
 src/compiler/interpreter.jl | 44 +++++++++++++++++++++++++++++++++++--
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl
index 3d5df1e266..ccec99f0b1 100644
--- a/src/compiler/interpreter.jl
+++ b/src/compiler/interpreter.jl
@@ -131,6 +131,10 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
     reverse_rules::Bool
     inactive_rules::Bool
     broadcast_rewrite::Bool
+
+    # When true, leave the check for within_autodiff to the handler.
+    defer_within_autodiff::Bool
+
     handler::T
 end
 
@@ -169,6 +173,7 @@ function EnzymeInterpreter(
     reverse_rules::Bool,
     inactive_rules::Bool,
     broadcast_rewrite::Bool = true,
+    defer_within_autodiff::Bool = false,
     handler = nothing
 )
     @assert world <= Base.get_world_counter()
@@ -229,6 +234,7 @@ function EnzymeInterpreter(
         reverse_rules::Bool,
         inactive_rules::Bool,
         broadcast_rewrite::Bool,
+        defer_within_autodiff::Bool,
         handler
     )
 end
@@ -240,8 +246,42 @@ EnzymeInterpreter(
     mode::API.CDerivativeMode,
     inactive_rules::Bool,
     broadcast_rewrite::Bool = true,
+    defer_within_autodiff::Bool = false,
     handler = nothing
-) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler)
+) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, defer_within_autodiff, handler)
+
+function EnzymeInterpreter(interp::EnzymeInterpreter;
+    cache_or_token = (@static if HAS_INTEGRATED_CACHE
+        interp.token
+    else
+        interp.code_cache
+    end),
+    mt = interp.method_table,
+    local_cache = interp.local_cache,
+    world = interp.world,
+    inf_params = interp.inf_params,
+    opt_params = interp.opt_params,
+    forward_rules = interp.forward_rules,
+    reverse_rules = interp.reverse_rules,
+    inactive_rules = interp.inactive_rules,
+    broadcast_rewrite = interp.broadcast_rewrite,
+    defer_within_autodiff = interp.defer_within_autodiff,
+    handler = interp.handler)
+    return EnzymeInterpreter(
+        cache_or_token,
+        mt,
+        local_cache,
+        world,
+        inf_params,
+        opt_params,
+        forward_rules,
+        reverse_rules,
+        inactive_rules,
+        broadcast_rewrite,
+        defer_within_autodiff,
+        handler
+    )
+end
 
 Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
 Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
@@ -909,7 +949,7 @@ function abstract_call_known(
 
     (; fargs, argtypes) = arginfo
 
-    if f === Enzyme.within_autodiff
+    if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff
         if length(argtypes) != 1
             @static if VERSION < v"1.11.0-"
                 return CallMeta(Union{}, Effects(), NoCallInfo())

From 9f540683cedcd61850bb6d2723bf47ce7f1c3189 Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx@users.noreply.github.com>
Date: Wed, 8 Jan 2025 13:10:30 +0100
Subject: [PATCH 2/2] `!defer_within_autodiff` -> `within_autodiff_rewrite`

---
 src/compiler/interpreter.jl | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl
index ccec99f0b1..84c26039b1 100644
--- a/src/compiler/interpreter.jl
+++ b/src/compiler/interpreter.jl
@@ -132,8 +132,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
     inactive_rules::Bool
     broadcast_rewrite::Bool
 
-    # When true, leave the check for within_autodiff to the handler.
-    defer_within_autodiff::Bool
+    # When false, leave the check for within_autodiff to the handler.
+    within_autodiff_rewrite::Bool
 
     handler::T
 end
@@ -173,7 +173,7 @@ function EnzymeInterpreter(
     reverse_rules::Bool,
     inactive_rules::Bool,
     broadcast_rewrite::Bool = true,
-    defer_within_autodiff::Bool = false,
+    within_autodiff_rewrite::Bool = true,
     handler = nothing
 )
     @assert world <= Base.get_world_counter()
@@ -234,7 +234,7 @@ function EnzymeInterpreter(
         reverse_rules::Bool,
         inactive_rules::Bool,
         broadcast_rewrite::Bool,
-        defer_within_autodiff::Bool,
+        within_autodiff_rewrite::Bool,
         handler
     )
 end
@@ -246,9 +246,9 @@ EnzymeInterpreter(
     mode::API.CDerivativeMode,
     inactive_rules::Bool,
     broadcast_rewrite::Bool = true,
-    defer_within_autodiff::Bool = false,
+    within_autodiff_rewrite::Bool = true,
     handler = nothing
-) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, defer_within_autodiff, handler)
+) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)
 
 function EnzymeInterpreter(interp::EnzymeInterpreter;
     cache_or_token = (@static if HAS_INTEGRATED_CACHE
@@ -265,7 +265,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
     reverse_rules = interp.reverse_rules,
     inactive_rules = interp.inactive_rules,
     broadcast_rewrite = interp.broadcast_rewrite,
-    defer_within_autodiff = interp.defer_within_autodiff,
+    within_autodiff_rewrite = interp.within_autodiff_rewrite,
     handler = interp.handler)
     return EnzymeInterpreter(
         cache_or_token,
@@ -278,7 +278,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
         reverse_rules,
         inactive_rules,
         broadcast_rewrite,
-        defer_within_autodiff,
+        within_autodiff_rewrite,
         handler
     )
 end
@@ -949,7 +949,7 @@ function abstract_call_known(
 
     (; fargs, argtypes) = arginfo
 
-    if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff
+    if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff
         if length(argtypes) != 1
             @static if VERSION < v"1.11.0-"
                 return CallMeta(Union{}, Effects(), NoCallInfo())