Skip to content

Commit

Permalink
refactor split inference tests for advi+distributionsad
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Nov 11, 2023
1 parent 3691f16 commit a063583
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 144 deletions.
208 changes: 64 additions & 144 deletions test/inference/advi_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,156 +3,76 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using Test

@testset "inference_advi" begin
@testset "distributionsad" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64], # Currently only tested against Float64
(modelname, modelconstr) Dict(
:Normal=> normal_meanfield,
),
(objname, objective) Dict(
:ADVIClosedFormEntropy => ADVI(10),
:ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
@testset "inference_advi_distributionsad" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64, Float32],
(modelname, modelconstr) Dict(
:Normal=> normal_meanfield,
),
(objname, objective) Dict(
:ADVIClosedFormEntropy => ADVI(10),
:ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)

seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

μ₀ = Zeros(realtype, n_dims)
L₀ = Diagonal(Ones(realtype, n_dims))
q₀_z = TuringDiagMvNormal(μ₀, diag(L₀))

@testset "convergence" begin
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)
μ = mean(q)
L = sqrt(cov(q))
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

μ₀ = Zeros(realtype, n_dims)
L₀ = Diagonal(Ones(realtype, n_dims))
q₀_z = TuringDiagMvNormal(μ₀, diag(L₀))

@testset "convergence" begin
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = mean(q)
L = sqrt(cov(q))
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end

@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = mean(q)
L = sqrt(cov(q))

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = mean(q)
L_repl = sqrt(cov(q))
@test μ == μ_repl
@test L == L_repl
end
@test Δλ Δλ₀/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
end
end

@testset "inference_bijectors_advi" begin
@testset "distributionsad" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64], # Currently only tested against Float64
(modelname, modelconstr) Dict(
:NormalLogNormalMeanField => normallognormal_meanfield,
),
(objname, objective) Dict(
:ADVIClosedFormEntropy => ADVI(10),
:ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

b = Bijectors.bijector(model)
b⁻¹ = inverse(b)
μ₀ = Zeros(realtype, n_dims)
L₀ = Diagonal(Ones(realtype, n_dims))

q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
q₀_z = Bijectors.transformed(q₀_η, b⁻¹)

@testset "convergence" begin
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = mean(q.dist)
L = sqrt(cov(q.dist))
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end

@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = mean(q.dist)
L = sqrt(cov(q.dist))

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = mean(q.dist)
L_repl = sqrt(cov(q.dist))
@test μ == μ_repl
@test L == L_repl
end
μ = mean(q)
L = sqrt(cov(q))

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = mean(q)
L_repl = sqrt(cov(q))
@test μ == μ_repl
@test L == L_repl
end
end
end

81 changes: 81 additions & 0 deletions test/inference/advi_distributionsad_bijectors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@

const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using Test

@testset "inference_advi_distributionsad_bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64, Float32],
(modelname, modelconstr) Dict(
:NormalLogNormalMeanField => normallognormal_meanfield,
),
(objname, objective) Dict(
:ADVIClosedFormEntropy => ADVI(10),
:ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)

seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

b = Bijectors.bijector(model)
b⁻¹ = inverse(b)
μ₀ = Zeros(realtype, n_dims)
L₀ = Diagonal(Ones(realtype, n_dims))

q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
q₀_z = Bijectors.transformed(q₀_η, b⁻¹)

@testset "convergence" begin
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = mean(q.dist)
L = sqrt(cov(q.dist))
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end

@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = mean(q.dist)
L = sqrt(cov(q.dist))

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q₀_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = mean(q.dist)
L_repl = sqrt(cov(q.dist))
@test μ == μ_repl
@test L == L_repl
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ include("interface/optimize.jl")
include("interface/advi.jl")

include("inference/advi_distributionsad.jl")
include("inference/advi_distributionsad_bijectors.jl")

0 comments on commit a063583

Please sign in to comment.