Skip to content

Commit

Permalink
Merge pull request #59 from mcabbott/vecbug
Browse files Browse the repository at this point in the history
Missing `vec` in gradient of `destructure`
  • Loading branch information
CarloLucibello authored Mar 6, 2022
2 parents 1e34fa2 + 140499e commit 0e70295
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
version = "0.2.0"
version = "0.2.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function _grad!(x, dx, off, flat::AbstractVector)
flat
end
function _grad!(x, dx, off::Integer, flat::AbstractVector)
@views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes
@views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes
flat
end
_grad!(x, dx::Zero, off, flat::AbstractVector) = dx
Expand Down
16 changes: 16 additions & 0 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
m1 = collect(1:3.0)
m2 = (collect(1:3.0), collect(4:6.0))
m3 = (x = m1, y = sin, z = collect(4:6.0))

m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
m6 = (a = m1, b = [4.0 + im], c = m1)

m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]

mat = Float32[4 6; 5 7]
m9 = (a = m1, b = mat, c = [mat, m1])

@testset "flatten & rebuild" begin
@test destructure(m1)[1] isa Vector{Float64}
@test destructure(m1)[1] == 1:3
Expand All @@ -16,6 +21,7 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
@test destructure(m4)[1] == 1:6
@test destructure(m5)[1] == vcat(1:6, 4:6)
@test destructure(m6)[1] == vcat(1:3, 4 + im)
@test destructure(m9)[1] == 1:7

@test destructure(m1)[2](7:9) == [7,8,9]
@test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
Expand Down Expand Up @@ -45,6 +51,10 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
@test m8′[2].b.y === false
@test m8′[3][1] == [5.0]

m9′ = destructure(m9)[2](10:10:70)
@test m9′.b === m9′.c[1]
@test m9′.b isa Matrix{Float32}

# errors
@test_throws Exception destructure(m7)[2]([10,20])
@test_throws Exception destructure(m7)[2]([10,20,30,40])
Expand All @@ -71,6 +81,9 @@ end
@test g8[2].b.x == [8]
@test g8[3] == [[10.0]]

g9 = gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
@test g9.c === nothing

@testset "second derivative" begin
@test gradient([1,2,3.0]) do v
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
Expand Down Expand Up @@ -119,6 +132,9 @@ end
@test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

re9 = destructure(m9)[2]
@test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]

@testset "second derivative" begin
@test_broken gradient(collect(1:6.0)) do y
sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
Expand Down

0 comments on commit 0e70295

Please sign in to comment.