Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove constant_function for AutoEnzyme #74

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = [
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
]
version = "1.6.1"
version = "1.6.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
54 changes: 5 additions & 49 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,73 +39,29 @@ struct AutoDiffractor <: AbstractADType end
mode(::AutoDiffractor) = ForwardOrReverseMode()

"""
AutoEnzyme{M,constant_function}
AutoEnzyme{M}

Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoEnzyme(; mode=nothing, constant_function::Bool=false)

The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl.
For simple functions, `constant_function` should usually be set to `true`, which leads to increased performance.
However, in the case of closures or callable structs which contain differentiated data, `constant_function` should be set to `false` to ensure correctness (more details below).
AutoEnzyme(; mode=nothing)

# Fields

- `mode::M`: can be either

+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
+ `nothing` to choose the best mode automatically

# Notes

We now give several examples of functions.
For each one, we explain how `constant_function` should be set in order to compute the correct derivative with respect to the input `x`.

```julia
function f1(x)
return x[1]
end
```

The function `f1` is not a closure, it does not contain any data.
Thus `f1` can be differentiated with `AutoEnzyme(constant_function=true)` (although here setting `constant_function=false` would change neither correctness nor performance).

```julia
parameter = [0.0]
function f2(x)
return parameter[1] + x[1]
end
```

The function `f2` is a closure over `parameter`, but `parameter` is never modified based on the input `x`.
Thus, `f2` can be differentiated with `AutoEnzyme(constant_function=true)` (setting `constant_function=false` would not change correctness but would hinder performance).

```julia
cache = [0.0]
function f3(x)
cache[1] = x[1]
return cache[1] + x[1]
end
```

The function `f3` is a closure over `cache`, and `cache` is modified based on the input `x`.
That means `cache` cannot be treated as constant, since derivative values must be propagated through it.
Thus `f3` must be differentiated with `AutoEnzyme(constant_function=false)` (setting `constant_function=true` would make the result incorrect).
"""
struct AutoEnzyme{M, constant_function} <: AbstractADType
struct AutoEnzyme{M} <: AbstractADType
mode::M
end

function AutoEnzyme(mode::M; constant_function::Bool = false) where {M}
return AutoEnzyme{M, constant_function}(mode)
end

function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M}
return AutoEnzyme{M, constant_function}(mode)
function AutoEnzyme(; mode::M = nothing) where {M}
return AutoEnzyme{M}(mode)
end

mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension
Expand Down
12 changes: 6 additions & 6 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,25 @@ end
@testset "AutoEnzyme" begin
ad = AutoEnzyme()
@test ad isa AbstractADType
@test ad isa AutoEnzyme{Nothing, false}
@test ad isa AutoEnzyme{Nothing}
@test mode(ad) isa ForwardOrReverseMode
@test ad.mode === nothing

ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true)
ad = AutoEnzyme(EnzymeCore.Forward)
@test ad isa AbstractADType
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true}
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
@test mode(ad) isa ForwardMode
@test ad.mode == EnzymeCore.Forward

ad = AutoEnzyme(; mode = EnzymeCore.Forward)
@test ad isa AbstractADType
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false}
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
@test mode(ad) isa ForwardMode
@test ad.mode == EnzymeCore.Forward

ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true)
ad = AutoEnzyme(; mode = EnzymeCore.Reverse)
@test ad isa AbstractADType
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true}
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)}
@test mode(ad) isa ReverseMode
@test ad.mode == EnzymeCore.Reverse
end
Expand Down
Loading