From aa6ef9cdfc80505665aebd0128adc8470d97bf97 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:45:01 +0100 Subject: [PATCH 1/2] add `defer_within_autodiff` to EnzymeInterpreter in order for `within_autodiff` to no return true during Reactant compilation. When this flag is true, `interp.handler` is responsible for handling within_autodiff, or to toggle defer_within_autodiff to false somewhere down the call chain. --- src/compiler/interpreter.jl | 44 +++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 3d5df1e266..ccec99f0b1 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -131,6 +131,10 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter reverse_rules::Bool inactive_rules::Bool broadcast_rewrite::Bool + + # When true, leave the check for within_autodiff to the handler. + defer_within_autodiff::Bool + handler::T end @@ -169,6 +173,7 @@ function EnzymeInterpreter( reverse_rules::Bool, inactive_rules::Bool, broadcast_rewrite::Bool = true, + defer_within_autodiff::Bool = false, handler = nothing ) @assert world <= Base.get_world_counter() @@ -229,6 +234,7 @@ function EnzymeInterpreter( reverse_rules::Bool, inactive_rules::Bool, broadcast_rewrite::Bool, + defer_within_autodiff::Bool, handler ) end @@ -240,8 +246,42 @@ EnzymeInterpreter( mode::API.CDerivativeMode, inactive_rules::Bool, broadcast_rewrite::Bool = true, + defer_within_autodiff::Bool = false, handler = nothing -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, defer_within_autodiff, handler) + +function EnzymeInterpreter(interp::EnzymeInterpreter; + cache_or_token = (@static if HAS_INTEGRATED_CACHE + interp.token + else + interp.code_cache + end), + mt = interp.method_table, + local_cache = interp.local_cache, + world = interp.world, + inf_params = interp.inf_params, + opt_params = interp.opt_params, + forward_rules = interp.forward_rules, + reverse_rules = interp.reverse_rules, + inactive_rules = interp.inactive_rules, + broadcast_rewrite = interp.broadcast_rewrite, + defer_within_autodiff = interp.defer_within_autodiff, + handler = interp.handler) + return EnzymeInterpreter( + cache_or_token, + mt, + local_cache, + world, + inf_params, + opt_params, + forward_rules, + reverse_rules, + inactive_rules, + broadcast_rewrite, + defer_within_autodiff, + handler + ) +end Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params @@ -909,7 +949,7 @@ function abstract_call_known( (; fargs, argtypes) = arginfo - if f === Enzyme.within_autodiff + if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff if length(argtypes) != 1 @static if VERSION < v"1.11.0-" return CallMeta(Union{}, Effects(), NoCallInfo()) From 9f540683cedcd61850bb6d2723bf47ce7f1c3189 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:10:30 +0100 Subject: [PATCH 2/2] `!defer_within_autodiff` -> `within_autodiff_rewrite` --- src/compiler/interpreter.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index ccec99f0b1..84c26039b1 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -132,8 +132,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter inactive_rules::Bool broadcast_rewrite::Bool - # When true, leave the check for within_autodiff to the handler. - defer_within_autodiff::Bool + # When false, leave the check for within_autodiff to the handler. + within_autodiff_rewrite::Bool handler::T end @@ -173,7 +173,7 @@ function EnzymeInterpreter( reverse_rules::Bool, inactive_rules::Bool, broadcast_rewrite::Bool = true, - defer_within_autodiff::Bool = false, + within_autodiff_rewrite::Bool = true, handler = nothing ) @assert world <= Base.get_world_counter() @@ -234,7 +234,7 @@ function EnzymeInterpreter( reverse_rules::Bool, inactive_rules::Bool, broadcast_rewrite::Bool, - defer_within_autodiff::Bool, + within_autodiff_rewrite::Bool, handler ) end @@ -246,9 +246,9 @@ EnzymeInterpreter( mode::API.CDerivativeMode, inactive_rules::Bool, broadcast_rewrite::Bool = true, - defer_within_autodiff::Bool = false, + within_autodiff_rewrite::Bool = true, handler = nothing -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, defer_within_autodiff, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler) function EnzymeInterpreter(interp::EnzymeInterpreter; cache_or_token = (@static if HAS_INTEGRATED_CACHE @@ -265,7 +265,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter; reverse_rules = interp.reverse_rules, inactive_rules = interp.inactive_rules, broadcast_rewrite = interp.broadcast_rewrite, - defer_within_autodiff = interp.defer_within_autodiff, + within_autodiff_rewrite = interp.within_autodiff_rewrite, handler = interp.handler) return EnzymeInterpreter( cache_or_token, @@ -278,7 +278,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter; reverse_rules, inactive_rules, broadcast_rewrite, - defer_within_autodiff, + within_autodiff_rewrite, handler ) end @@ -949,7 +949,7 @@ function abstract_call_known( (; fargs, argtypes) = arginfo - if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff + if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff if length(argtypes) != 1 @static if VERSION < v"1.11.0-" return CallMeta(Union{}, Effects(), NoCallInfo())