diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..e4fecbbe --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +Manifest.toml +.vscode/ diff --git a/src/interface.jl b/src/interface.jl index 9ec67e50..ae831906 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -50,7 +50,7 @@ function update(tree, x, x̄s...) end # default all rules to first order calls -apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx) +apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx) """ isnumeric(x) -> Bool diff --git a/test/runtests.jl b/test/runtests.jl index 6e1bf28f..59f2cb0a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,19 @@ struct TwoThirds a; b; c; end Functors.@functor TwoThirds (a, c) Optimisers.trainable(x::TwoThirds) = (a = x.a,) +struct DummyHigherOrder <: AbstractRule end + +Optimisers.init(::DummyHigherOrder, x::AbstractArray) = + (ones(eltype(x), size(x)), zero(x)) + +dummy_update_rule(st, p, dx, dx2) = @. p - (st[1] * dx + st[2] * dx2) +function Optimisers.apply!(::DummyHigherOrder, state, x, dx, dx2) + a, b = state + @.. dx = a * dx + b * dx2 + + return (a .+ 1, b .+ 1), dx +end + @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin @@ -220,6 +233,33 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) @test_throws ArgumentError Optimisers.setup(AdamW(), m2) end + @testset "higher order interface" begin + w, b = rand(3, 4), rand(3) + + o = DummyHigherOrder() + psin = (w, b) + dxs = map(x -> rand(size(x)...), psin) + dx2s = map(x -> rand(size(x)...), psin) + stin = Optimisers.setup(o, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + + # hardcoded rule behavior for dummy rule + @test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2]) + @test stout[1].state[1] == stin[1].state[1] .+ 1 + @test stout[2].state[2] == stin[2].state[2] .+ 1 + + # error if only given one derivative + @test_throws MethodError Optimisers.update(stin, psin, dxs) + + # first-order rules compose with second-order + ochain = OptimiserChain(Descent(0.1), o) + stin = Optimisers.setup(ochain, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + @test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2]) + end + end @testset verbose=true "Destructure" begin include("destructure.jl")