Skip to content

Commit

Permalink
add Duplicated methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 8, 2024
1 parent 38c9d62 commit fac710f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
14 changes: 11 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.4.1"
authors = ["Mike J Innes <[email protected]>"]
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
OptimisersEnzymeCoreExt = "EnzymeCore"

[compat]
ChainRulesCore = "1"
EnzymeCore = "0.8.5"
Functors = "0.4.9, 0.5"
Statistics = "1"
Zygote = "0.6.40"
julia = "1.6"
julia = "1.10"

[extras]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]
56 changes: 56 additions & 0 deletions ext/OptimisersEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
module OptimisersEnzymeCoreExt

import Optimisers: trainable, setup, update!, isnumeric, AbstractRule
import EnzymeCore: Duplicated, Const

using Functors: fmapstructure

# Expose only the `val` field of a `Duplicated` as trainable, so that the
# gradient stored alongside it in `dval` is never treated as a parameter.
trainable(x::Duplicated) = (; val = x.val)

"""
    setup(rule::AbstractRule, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, this just calls `setup(rule, model_grad.val)`.
The gradient half `model_grad.dval` is not inspected.
"""
setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)

"""
    update!(opt_state, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, which holds both a model/parameters
and the corresponding gradient.

Calls `update!(opt_state, model_grad.val, grad)`, where `grad` is built from
`model_grad.dval` with every non-numeric leaf replaced by `nothing`.
Both `opt_state` and `model_grad.val` are mutated; the method itself returns `nothing`.

# Example

```jldoctest
julia> using Optimisers, EnzymeCore

julia> x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
Duplicated{Vector{Float16}}(Float16[1.0, 2.0, 3.0], Float16[1.0, 0.0, -4.0])

julia> st = Optimisers.setup(Momentum(1/9), x_dx)  # acts only on x not on dx
Leaf(Momentum(0.111111, 0.9), Float16[0.0, 0.0, 0.0])

julia> Optimisers.update!(st, x_dx)  # mutates both arguments

julia> x_dx
Duplicated{Vector{Float16}}(Float16[0.8887, 2.0, 3.445], Float16[1.0, 0.0, -4.0])

julia> st
Leaf(Momentum(0.111111, 0.9), Float16[0.1111, 0.0, -0.4443])
```
"""
function update!(opt_state, model_grad::Duplicated)
    # The three-argument `update!` returns a pair; both halves are discarded
    # here because the caller already holds the objects mutated in place.
    _, _ = update!(opt_state, model_grad.val, _grad_or_nothing(model_grad))
    return nothing
end

# These methods strip the gradient down to a Zygote-like structure:
#   * a `Duplicated` is replaced by the stripped structure of its `dval`
#     (recursing through `fmapstructure`, with `prune=nothing`),
#   * a `Const` contributes no gradient at all,
#   * any other leaf is kept only when `isnumeric` says it is numeric.
_grad_or_nothing(dup::Duplicated) = fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
_grad_or_nothing(::Const) = nothing
_grad_or_nothing(x) = isnumeric(x) ? x : nothing

end

0 comments on commit fac710f

Please sign in to comment.