diff --git a/Project.toml b/Project.toml
index 0a19f49..a60b58f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,21 +1,29 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
+version = "0.4.1"
 authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
-version = "0.4.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
+[weakdeps]
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+
+[extensions]
+OptimisersEnzymeCoreExt = "EnzymeCore"
+
 [compat]
 ChainRulesCore = "1"
+EnzymeCore = "0.8.5"
 Functors = "0.4.9, 0.5"
 Statistics = "1"
 Zygote = "0.6.40"
-julia = "1.6"
+julia = "1.10"
 
 [extras]
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -23,4 +31,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "StaterArrays", "Zygote"]
+test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]
diff --git a/ext/OptimisersEnzymeCoreExt.jl b/ext/OptimisersEnzymeCoreExt.jl
new file mode 100644
index 0000000..89b5f9a
--- /dev/null
+++ b/ext/OptimisersEnzymeCoreExt.jl
@@ -0,0 +1,53 @@
+module OptimisersEnzymeCoreExt
+
+import Optimisers: trainable, setup, update!, isnumeric, AbstractRule
+import EnzymeCore: Duplicated, Const
+
+using Functors: fmapstructure
+
+# Tell Optimisers to look inside a Duplicated, but only at the value, not the gradient:
+trainable(x::Duplicated) = (; val = x.val)
+
+"""
+    setup(rule::AbstractRule, model_grad::Duplicated)
+
+For use with Enzyme's Duplicated, this just calls `setup(rule, model_grad.val)`.
+"""
+setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)
+
+"""
+    update!(opt_state, model_grad::Duplicated)
+
+For use with Enzyme's `Duplicated`, which holds both a model/parameters
+and the corresponding gradient.
+
+# Example
+
+```jldoctest
+julia> using Optimisers, EnzymeCore
+
+julia> x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
+Duplicated{Vector{Float16}}(Float16[1.0, 2.0, 3.0], Float16[1.0, 0.0, -4.0])
+
+julia> st = Optimisers.setup(Momentum(1/9), x_dx)  # acts only on x not on dx
+Leaf(Momentum(0.111111, 0.9), Float16[0.0, 0.0, 0.0])
+
+julia> Optimisers.update!(st, x_dx)  # mutates both arguments
+
+julia> x_dx
+Duplicated{Vector{Float16}}(Float16[0.8887, 2.0, 3.445], Float16[1.0, 0.0, -4.0])
+
+julia> st
+Leaf(Momentum(0.111111, 0.9), Float16[0.1111, 0.0, -0.4443])
+```
+"""
+function update!(opt_state, model_grad::Duplicated)
+    # Discard the returned (state, model) pair: both arguments are mutated in place.
+    _, _ = update!(opt_state, model_grad.val, _grad_or_nothing(model_grad))
+    nothing
+end
+
+# This function strips the returned gradient to be Zygote-like:
+_grad_or_nothing(dup::Duplicated) = fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
+_grad_or_nothing(::Const) = nothing
+_grad_or_nothing(x) = isnumeric(x) ? x : nothing
+
+end