Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add defer_within_autodiff to EnzymeInterpreter #2254

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@
reverse_rules::Bool
inactive_rules::Bool
broadcast_rewrite::Bool

# When false, leave the check for within_autodiff to the handler.
within_autodiff_rewrite::Bool

handler::T
end

Expand Down Expand Up @@ -169,6 +173,7 @@
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
within_autodiff_rewrite::Bool = true,
handler = nothing
)
@assert world <= Base.get_world_counter()
Expand Down Expand Up @@ -229,6 +234,7 @@
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool,
within_autodiff_rewrite::Bool,
handler
)
end
Expand All @@ -240,8 +246,42 @@
mode::API.CDerivativeMode,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
within_autodiff_rewrite::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)

function EnzymeInterpreter(interp::EnzymeInterpreter;

Check warning on line 253 in src/compiler/interpreter.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/interpreter.jl#L253

Added line #L253 was not covered by tests
cache_or_token = (@static if HAS_INTEGRATED_CACHE
interp.token
else
interp.code_cache
end),
mt = interp.method_table,
local_cache = interp.local_cache,
world = interp.world,
inf_params = interp.inf_params,
opt_params = interp.opt_params,
forward_rules = interp.forward_rules,
reverse_rules = interp.reverse_rules,
inactive_rules = interp.inactive_rules,
broadcast_rewrite = interp.broadcast_rewrite,
within_autodiff_rewrite = interp.within_autodiff_rewrite,
handler = interp.handler)
return EnzymeInterpreter(

Check warning on line 270 in src/compiler/interpreter.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/interpreter.jl#L270

Added line #L270 was not covered by tests
cache_or_token,
mt,
local_cache,
world,
inf_params,
opt_params,
forward_rules,
reverse_rules,
inactive_rules,
broadcast_rewrite,
within_autodiff_rewrite,
handler
)
end

Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
Expand Down Expand Up @@ -909,7 +949,7 @@

(; fargs, argtypes) = arginfo

if f === Enzyme.within_autodiff
if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
Expand Down
Loading