Skip to content

Commit

Permalink
group the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 12, 2022
1 parent d13e52a commit 3bca907
Showing 1 changed file with 51 additions and 44 deletions.
95 changes: 51 additions & 44 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using Optimisers: @.., @lazy

Random.seed!(1)

# Fake "models" for testing

struct Foo; x; y; end
Functors.@functor Foo
Optimisers.trainable(x::Foo) = (x.y, x.x)
Expand All @@ -16,6 +18,8 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
mutable struct MutTwo; x; y; end
Functors.@functor MutTwo

# Simple rules for testing

struct DummyHigherOrder <: AbstractRule end
Optimisers.init(::DummyHigherOrder, x::AbstractArray) =
(ones(eltype(x), size(x)), zero(x))
Expand Down Expand Up @@ -227,23 +231,6 @@ end
@test_throws MethodError Optimisers.update(sm, m)
end

@testset "2nd order gradient" begin
m == ([1.0], sin), γ = Float32[4,3,2])

# Special rule which requires this:
s = Optimisers.setup(BiRule(), m)
g == ([0.1], ZeroTangent()), γ = [1,10,100],)
s1, m1 = Optimisers.update(s, m, g, g)
@test m1.α[1] == [0.9]
@test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g))

# Ordinary rule which doesn't need it:
s2 = Optimisers.setup(Adam(), m)
s3, m3 = Optimisers.update(s2, m, g)
s4, m4 = Optimisers.update(s2, m, g, g)
@test m3.γ == m4.γ
end

@testset "broadcasting macros" begin
x = [1.0, 2.0]; y = [3,4]; z = [5,6]
@test (@lazy x + y * z) isa Broadcast.Broadcasted
Expand Down Expand Up @@ -365,34 +352,54 @@ end
@test model2.a === model2.b # tie of MutTwo structs is restored
@test model2.a !== model2.c # but a new tie is not created
end
end
end # tied weights

@testset "2nd-order interface" begin
@testset "BiRule" begin
m == ([1.0], sin), γ = Float32[4,3,2])

# Special rule which requires this:
s = Optimisers.setup(BiRule(), m)
g == ([0.1], ZeroTangent()), γ = [1,10,100],)
s1, m1 = Optimisers.update(s, m, g, g)
@test m1.α[1] == [0.9]
@test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g))

# Ordinary rule which doesn't need it:
s2 = Optimisers.setup(Adam(), m)
s3, m3 = Optimisers.update(s2, m, g)
s4, m4 = Optimisers.update(s2, m, g, g)
@test m3.γ == m4.γ
end

@testset "higher order interface" begin
w, b = rand(3, 4), rand(3)

o = DummyHigherOrder()
psin = (w, b)
dxs = map(x -> rand(size(x)...), psin)
dx2s = map(x -> rand(size(x)...), psin)
stin = Optimisers.setup(o, psin)
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)

# hardcoded rule behavior for dummy rule
@test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1])
@test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2])
@test stout[1].state[1] == stin[1].state[1] .+ 1
@test stout[2].state[2] == stin[2].state[2] .+ 1

# error if only given one derivative
@test_throws MethodError Optimisers.update(stin, psin, dxs)

# first-order rules compose with second-order
ochain = OptimiserChain(Descent(0.1), o)
stin = Optimisers.setup(ochain, psin)
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
@test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1])
@test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2])
end
@testset "DummyHigherOrder" begin
w, b = rand(3, 4), rand(3)

o = DummyHigherOrder()
psin = (w, b)
dxs = map(x -> rand(size(x)...), psin)
dx2s = map(x -> rand(size(x)...), psin)
stin = Optimisers.setup(o, psin)
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)

# hardcoded rule behavior for dummy rule
@test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1])
@test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2])
@test stout[1].state[1] == stin[1].state[1] .+ 1
@test stout[2].state[2] == stin[2].state[2] .+ 1

# error if only given one derivative
@test_throws MethodError Optimisers.update(stin, psin, dxs)

# first-order rules compose with second-order
ochain = OptimiserChain(Descent(0.1), o)
stin = Optimisers.setup(ochain, psin)
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
@test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1])
@test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2])
end
end # 2nd-order
end

end
@testset verbose=true "Destructure" begin
Expand Down

0 comments on commit 3bca907

Please sign in to comment.