diff --git a/Project.toml b/Project.toml index 66c062ed..e02c589f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Optimisers" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" authors = ["Mike J Innes "] -version = "0.2.0" +version = "0.2.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/destructure.jl b/src/destructure.jl index 3ace52ec..2b91983d 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -131,7 +131,7 @@ function _grad!(x, dx, off, flat::AbstractVector) flat end function _grad!(x, dx, off::Integer, flat::AbstractVector) - @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes + @views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes flat end _grad!(x, dx::Zero, off, flat::AbstractVector) = dx diff --git a/test/destructure.jl b/test/destructure.jl index 40c4360c..043315b3 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -2,12 +2,17 @@ m1 = collect(1:3.0) m2 = (collect(1:3.0), collect(4:6.0)) m3 = (x = m1, y = sin, z = collect(4:6.0)) + m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) m6 = (a = m1, b = [4.0 + im], c = m1) + m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]] +mat = Float32[4 6; 5 7] +m9 = (a = m1, b = mat, c = [mat, m1]) + @testset "flatten & rebuild" begin @test destructure(m1)[1] isa Vector{Float64} @test destructure(m1)[1] == 1:3 @@ -16,6 +21,7 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]] @test destructure(m4)[1] == 1:6 @test destructure(m5)[1] == vcat(1:6, 4:6) @test destructure(m6)[1] == vcat(1:3, 4 + im) + @test destructure(m9)[1] == 1:7 @test destructure(m1)[2](7:9) == [7,8,9] @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9]) @@ -45,6 +51,10 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]] @test m8′[2].b.y === false @test m8′[3][1] == [5.0] + m9′ = destructure(m9)[2](10:10:70) + @test m9′.b === m9′.c[1] + @test m9′.b isa Matrix{Float32} + # errors @test_throws Exception destructure(m7)[2]([10,20]) @test_throws Exception destructure(m7)[2]([10,20,30,40]) @@ -71,6 +81,9 @@ end @test g8[2].b.x == [8] @test g8[3] == [[10.0]] + g9 = gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1] + @test g9.c === nothing + @testset "second derivative" begin @test gradient([1,2,3.0]) do v sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) @@ -119,6 +132,9 @@ end @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] + re9 = destructure(m9)[2] + @test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] + @testset "second derivative" begin @test_broken gradient(collect(1:6.0)) do y sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])