diff --git a/test/runtests.jl b/test/runtests.jl index 88eaf976..609bd499 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,8 @@ using Optimisers: @.., @lazy Random.seed!(1) +# Fake "models" for testing + struct Foo; x; y; end Functors.@functor Foo Optimisers.trainable(x::Foo) = (x.y, x.x) @@ -16,6 +18,8 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) mutable struct MutTwo; x; y; end Functors.@functor MutTwo +# Simple rules for testing + struct DummyHigherOrder <: AbstractRule end Optimisers.init(::DummyHigherOrder, x::AbstractArray) = (ones(eltype(x), size(x)), zero(x)) @@ -227,23 +231,6 @@ end @test_throws MethodError Optimisers.update(sm, m) end - @testset "2nd order gradient" begin - m = (α = ([1.0], sin), γ = Float32[4,3,2]) - - # Special rule which requires this: - s = Optimisers.setup(BiRule(), m) - g = (α = ([0.1], ZeroTangent()), γ = [1,10,100],) - s1, m1 = Optimisers.update(s, m, g, g) - @test m1.α[1] == [0.9] - @test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g)) - - # Ordinary rule which doesn't need it: - s2 = Optimisers.setup(Adam(), m) - s3, m3 = Optimisers.update(s2, m, g) - s4, m4 = Optimisers.update(s2, m, g, g) - @test m3.γ == m4.γ - end - @testset "broadcasting macros" begin x = [1.0, 2.0]; y = [3,4]; z = [5,6] @test (@lazy x + y * z) isa Broadcast.Broadcasted @@ -365,34 +352,54 @@ end @test model2.a === model2.b # tie of MutTwo structs is restored @test model2.a !== model2.c # but a new tie is not created end - end + end # tied weights + + @testset "2nd-order interface" begin + @testset "BiRule" begin + m = (α = ([1.0], sin), γ = Float32[4,3,2]) + + # Special rule which requires this: + s = Optimisers.setup(BiRule(), m) + g = (α = ([0.1], ZeroTangent()), γ = [1,10,100],) + s1, m1 = Optimisers.update(s, m, g, g) + @test m1.α[1] == [0.9] + @test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g)) + + # Ordinary rule which doesn't need it: + s2 = Optimisers.setup(Adam(), m) + s3, m3 = Optimisers.update(s2, m, g) + s4, m4 = Optimisers.update(s2, m, g, g) + @test m3.γ == m4.γ + end - @testset "higher order interface" begin - w, b = rand(3, 4), rand(3) - - o = DummyHigherOrder() - psin = (w, b) - dxs = map(x -> rand(size(x)...), psin) - dx2s = map(x -> rand(size(x)...), psin) - stin = Optimisers.setup(o, psin) - stout, psout = Optimisers.update(stin, psin, dxs, dx2s) - - # hardcoded rule behavior for dummy rule - @test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1]) - @test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2]) - @test stout[1].state[1] == stin[1].state[1] .+ 1 - @test stout[2].state[2] == stin[2].state[2] .+ 1 - - # error if only given one derivative - @test_throws MethodError Optimisers.update(stin, psin, dxs) - - # first-order rules compose with second-order - ochain = OptimiserChain(Descent(0.1), o) - stin = Optimisers.setup(ochain, psin) - stout, psout = Optimisers.update(stin, psin, dxs, dx2s) - @test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1]) - @test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2]) - end + @testset "DummyHigherOrder" begin + w, b = rand(3, 4), rand(3) + + o = DummyHigherOrder() + psin = (w, b) + dxs = map(x -> rand(size(x)...), psin) + dx2s = map(x -> rand(size(x)...), psin) + stin = Optimisers.setup(o, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + + # hardcoded rule behavior for dummy rule + @test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2]) + @test stout[1].state[1] == stin[1].state[1] .+ 1 + @test stout[2].state[2] == stin[2].state[2] .+ 1 + + # error if only given one derivative + @test_throws MethodError Optimisers.update(stin, psin, dxs) + + # first-order rules compose with second-order + ochain = OptimiserChain(Descent(0.1), o) + stin = Optimisers.setup(ochain, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + @test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2]) + end + end # 2nd-order + end end @testset verbose=true "Destructure" begin