
Commit a792707

more tests
mcabbott committed Mar 8, 2022
1 parent 3c71947 commit a792707
Showing 1 changed file with 92 additions and 0 deletions.
test/utils.jl (92 additions, 0 deletions)
@@ -514,3 +514,95 @@ end
    @test n_iter == 3
  end
end

@testset "Various destructure bugs" begin

@testset "issue 1601" begin
    struct TwoDenses
      dense::Dense
      dense2::Dense
    end
    Flux.@functor TwoDenses

    function (m::TwoDenses)(x)
      out = m.dense(x)
    end

    model = TwoDenses(
      Dense(3,1),
      Dense(3,2)
    )
    p, re = Flux.destructure(model)

    x = [1., 2., 3.]
    y, back = Flux.Zygote.pullback((x, p) -> re(p)(x), x, p)

    dy = [4.]
    dx, dp = back(dy)
    @test length(p) == length(dp)
  end

@testset "issue 1727" begin
    p, re = Flux.destructure(BatchNorm(3))  # 6 parameters, plus 6 non-trainable
    @test length(p) == 6

    x = rand(Float32, 3, 4)
    y, back = Flux.pullback(x, p) do x, p
      vec(re(p)(x))
    end
    @test_nowarn back(y)
    b = back(y)

    @test size(b[1]) == size(x)
    @test size(b[2]) == size(p)
  end

@testset "issue 1767" begin
    struct Model{A}
      a::A
      b::A
    end
    Flux.@functor Model
    (m::Model)(x) = m.a(x) .+ m.b(x)

    d = Dense(1, 1)
    x = rand(Float32, 1, 1)

    # Sharing the parameters
    model = Model(d, d)

    # Works
    g1 = Flux.gradient(() -> sum(model(x)), Flux.params(model))

    p, re = Flux.destructure(model)
    # Fails
    g2 = Flux.gradient(p -> sum(re(p)(x)), p)

    @test g2[1] ≈ vcat(g1[d.weight], g1[d.bias])
  end

@testset "issue 1826" begin
    struct Split{T}  # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer
      paths::T
    end
    Split(paths...) = Split(paths)
    Flux.@functor Split
    (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

    n_input, n_batch, n_shared = 5, 13, 11
    n_outputs = [3, 7]

    data = rand(Float32, n_input, n_batch)
    model = Chain(
      Dense(n_input, n_shared),
      Split(Dense(n_shared, n_outputs[1]), Dense(n_shared, n_outputs[2]))
    )

    pvec, re = Flux.destructure(model)
    loss(x, idx, pv) = sum(abs2, re(pv)(x)[idx])  # loss wrt `idx`th output term

    g = Flux.Zygote.ForwardDiff.gradient(pv -> loss(data, 1, pv), pvec)
    @test g ≈ Flux.Zygote.gradient(pv -> loss(data, 1, pv), pvec)[1]
  end

end
