Skip to content

Commit

Permalink
Merge pull request #110 from darsnack/higher-order-tests
Browse files Browse the repository at this point in the history
Add tests for higher order interface
  • Loading branch information
darsnack authored Sep 2, 2022
2 parents 5f51632 + 6133591 commit bf54f76
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Manifest.toml
.vscode/
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function update(tree, x, x̄s...)
end

# default all rules to first order calls
apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx)
apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)

"""
isnumeric(x) -> Bool
Expand Down
40 changes: 40 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

struct DummyHigherOrder <: AbstractRule end

Optimisers.init(::DummyHigherOrder, x::AbstractArray) =
(ones(eltype(x), size(x)), zero(x))

dummy_update_rule(st, p, dx, dx2) = @. p - (st[1] * dx + st[2] * dx2)
function Optimisers.apply!(::DummyHigherOrder, state, x, dx, dx2)
a, b = state
@.. dx = a * dx + b * dx2

return (a .+ 1, b .+ 1), dx
end

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin

Expand Down Expand Up @@ -220,6 +233,33 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@test_throws ArgumentError Optimisers.setup(AdamW(), m2)
end

@testset "higher order interface" begin
w, b = rand(3, 4), rand(3)

o = DummyHigherOrder()
psin = (w, b)
dxs = map(x -> rand(size(x)...), psin)
dx2s = map(x -> rand(size(x)...), psin)
stin = Optimisers.setup(o, psin)
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)

# hardcoded rule behavior for dummy rule
@test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1])
@test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2])
@test stout[1].state[1] == stin[1].state[1] .+ 1
@test stout[2].state[2] == stin[2].state[2] .+ 1

# error if only given one derivative
@test_throws MethodError Optimisers.update(stin, psin, dxs)

# first-order rules compose with second-order
ochain = OptimiserChain(Descent(0.1), o)
stin = Optimisers.setup(ochain, psin)
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
@test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1])
@test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2])
end

end
@testset verbose=true "Destructure" begin
include("destructure.jl")
Expand Down

0 comments on commit bf54f76

Please sign in to comment.