From 316b629eb965a591019b7149bbcf7fc72e613b9b Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 21 Nov 2023 01:50:41 -0500
Subject: [PATCH] refactor rename advi to repgradelbo and not use bijectors
 directly

---
 Project.toml                                  |   2 -
 src/AdvancedVI.jl                             |   9 +-
 src/objectives/elbo/advi.jl                   | 166 ------------------
 src/objectives/elbo/repgradelbo.jl            | 126 +++++++++++++
 src/utils.jl                                  |   3 +
 test/Project.toml                             |   2 -
 .../advi_distributionsad_bijectors.jl         |  81 ---------
 ...nsad.jl => repgradelbo_distributionsad.jl} |  20 +--
 test/interface/advi.jl                        |  55 ------
 test/interface/optimize.jl                    |  22 ++-
 test/interface/repgradelbo.jl                 |  28 +++
 test/models/normallognormal.jl                |  65 -------
 test/runtests.jl                              |   8 +-
 13 files changed, 182 insertions(+), 405 deletions(-)
 delete mode 100644 src/objectives/elbo/advi.jl
 create mode 100644 src/objectives/elbo/repgradelbo.jl
 delete mode 100644 test/inference/advi_distributionsad_bijectors.jl
 rename test/inference/{advi_distributionsad.jl => repgradelbo_distributionsad.jl} (78%)
 delete mode 100644 test/interface/advi.jl
 create mode 100644 test/interface/repgradelbo.jl
 delete mode 100644 test/models/normallognormal.jl

diff --git a/Project.toml b/Project.toml
index 70041561..7799d505 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,7 +5,6 @@ version = "0.3.0"
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -36,7 +35,6 @@ AdvancedVIZygoteExt = "Zygote"
 [compat]
 ADTypes = "0.1, 0.2"
 Accessors = "0.1"
-Bijectors = "0.12, 0.13"
 ChainRulesCore = "1.16"
 DiffResults = "1"
 Distributions = "0.25.87"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index b1decc4a..bb5b6e85 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -11,7 +11,6 @@ using Functors
 using Optimisers

 using DocStringExtensions
-
 using ProgressMeter
 using LinearAlgebra

@@ -21,7 +20,6 @@ using ADTypes, DiffResults
 using ChainRulesCore

 using FillArrays
-using Bijectors

 using StatsBase

@@ -115,18 +113,17 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
 """
 function estimate_gradient! end

-# ADVI-specific interfaces
+# ELBO-specific interfaces
 abstract type AbstractEntropyEstimator end

 export
-    ADVI,
+    RepGradELBO,
     ClosedFormEntropy,
     StickingTheLandingEntropy,
     MonteCarloEntropy

-# entropy.jl must preceed advi.jl
 include("objectives/elbo/entropy.jl")
-include("objectives/elbo/advi.jl")
+include("objectives/elbo/repgradelbo.jl")

 # Optimization Routine
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
deleted file mode 100644
index 98f8ae99..00000000
--- a/src/objectives/elbo/advi.jl
+++ /dev/null
@@ -1,166 +0,0 @@
-
-"""
-    ADVI(n_samples; kwargs...)
-
-Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective.
-This computes the evidence lower-bound (ELBO) through the ADVI formulation:
-```math
-\\begin{aligned}
-\\mathrm{ADVI}\\left(\\lambda\\right)
-&\\triangleq
-\\mathbb{E}_{\\eta \\sim q_{\\lambda}}\\left[
-    \\log \\pi\\left( \\phi^{-1}\\left( \\eta \\right) \\right)
-    +
-    \\log \\lvert J_{\\phi^{-1}}\\left(\\eta\\right) \\rvert
-\\right]
-+ \\mathbb{H}\\left(q_{\\lambda}\\right),
-\\end{aligned}
-```
-where ``\\phi^{-1}`` is an "inverse bijector."
-
-# Arguments
-- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
-
-# Keyword Arguments
-- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy())
-
-# Requirements
-- ``q_{\\lambda}`` implements `rand`.
-- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend.
-
-Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
-
-# References
-* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
-* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
-"""
-struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
-    entropy  ::EntropyEst
-    n_samples::Int
-end
-
-ADVI(
-    n_samples::Int;
-    entropy  ::AbstractEntropyEstimator = ClosedFormEntropy()
-) = ADVI(entropy, n_samples)
-
-Base.show(io::IO, advi::ADVI) =
-    print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
-
-maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
-
-maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
-
-function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, mc_samples, q, q_stop)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    estimate_entropy(entropy_estimator, mc_samples, q_maybe_stop)
-end
-
-function estimate_energy_with_samples(::ADVI, mc_samples::AbstractMatrix, prob)
-    mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples))
-end
-
-function estimate_energy_with_samples_bijector(::ADVI, mc_samples::AbstractMatrix, invbij, prob)
-    mean(eachcol(mc_samples)) do mc_sample
-        mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(invbij, mc_sample)
-        LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ
-    end
-end
-
-function estimate_advi_maybe_stl_with_samples(
-    advi      ::ADVI,
-    q         ::ContinuousDistribution,
-    q_stop    ::ContinuousDistribution,
-    mc_samples::AbstractMatrix,
-    prob
-)
-    energy  = estimate_energy_with_samples(advi, mc_samples, prob)
-    entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop)
-    energy + entropy
-end
-
-function estimate_advi_maybe_stl_with_samples(
-    advi        ::ADVI,
-    q_trans     ::Bijectors.TransformedDistribution,
-    q_trans_stop::Bijectors.TransformedDistribution,
-    mc_samples  ::AbstractMatrix,
-    prob
-)
-    q      = q_trans.dist
-    invbij = q_trans.transform
-    q_stop = q_trans_stop.dist
-    energy  = estimate_energy_with_samples_bijector(advi, mc_samples, invbij, prob)
-    entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop)
-    energy + entropy
-end
-
-rand_unconstrained(
-    rng      ::Random.AbstractRNG,
-    q        ::ContinuousDistribution,
-    n_samples::Int
-) = rand(rng, q, n_samples)
-
-rand_unconstrained(
-    rng      ::Random.AbstractRNG,
-    q        ::Bijectors.TransformedDistribution,
-    n_samples::Int
-) = rand(rng, q.dist, n_samples)
-
-function estimate_advi_maybe_stl(rng::Random.AbstractRNG, advi::ADVI, q, q_stop, prob)
-    mc_samples = rand_unconstrained(rng, q, advi.n_samples)
-    estimate_advi_maybe_stl_with_samples(advi, q, q_stop, mc_samples, prob)
-end
-
-"""
-    estimate_objective([rng,] advi, q, prob; n_samples)
-
-Estimate the ELBO using the ADVI formulation.
-
-# Arguments
-- `advi::ADVI`: ADVI objective.
-- `q`: Variational approximation
-- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
-
-# Keyword Arguments
-- `n_samples::Int = advi.n_samples`: Number of samples to be used to estimate the objective.
-
-# Returns
-- `obj_est`: Estimate of the objective value.
-"""
-function estimate_objective(
-    rng      ::Random.AbstractRNG,
-    advi     ::ADVI,
-    q,
-    prob;
-    n_samples::Int = advi.n_samples
-)
-    mc_samples = rand_unconstrained(rng, q, n_samples)
-    estimate_advi_maybe_stl_with_samples(advi, q, q, mc_samples, prob)
-end
-
-estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) =
-    estimate_objective(Random.default_rng(), advi, q, prob; n_samples)
-
-function estimate_gradient!(
-    rng       ::Random.AbstractRNG,
-    advi      ::ADVI,
-    adbackend ::ADTypes.AbstractADType,
-    out       ::DiffResults.MutableDiffResult,
-    prob,
-    λ,
-    restructure,
-    est_state,
-)
-    q_stop = restructure(λ)
-    function f(λ′)
-        q = restructure(λ′)
-        elbo = estimate_advi_maybe_stl(rng, advi, q, q_stop, prob)
-        -elbo
-    end
-    value_and_gradient!(adbackend, f, λ, out)
-
-    nelbo = DiffResults.value(out)
-    stat  = (elbo=-nelbo,)
-
-    out, nothing, stat
-end
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
new file mode 100644
index 00000000..09ba1a79
--- /dev/null
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -0,0 +1,126 @@
+
+"""
+    RepGradELBO(n_samples; kwargs...)
+
+Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014].
+This computes the evidence lower-bound (ELBO) through the formulation:
+```math
+\\begin{aligned}
+\\mathrm{ELBO}\\left(\\lambda\\right)
+&\\triangleq
+\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
+    \\log \\pi\\left(z\\right)
+\\right]
++ \\mathbb{H}\\left(q_{\\lambda}\\right),
+\\end{aligned}
+```
+
+# Arguments
+- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
+
+# Keyword Arguments
+- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy())
+
+# Requirements
+- ``q_{\\lambda}`` implements `rand`.
+- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend.
+
+Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
+
+# References
+[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In ICML.
+[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In ICML.
+[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In ICLR.
+""" +struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + entropy ::EntropyEst + n_samples::Int +end + +RepGradELBO( + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy() +) = RepGradELBO(entropy, n_samples) + +Base.show(io::IO, obj::RepGradELBO) = + print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))") + +maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop + +maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q + +function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) + q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) + estimate_entropy(entropy_estimator, samples, q_maybe_stop) +end + +function estimate_energy_with_samples(::RepGradELBO, samples, prob) + mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) +end + +function estimate_repgradelbo_maybe_stl_with_samples( + obj::RepGradELBO, q, q_stop, samples::AbstractMatrix, prob +) + energy = estimate_energy_with_samples(obj, samples, prob) + entropy = estimate_entropy_maybe_stl(obj.entropy, samples, q, q_stop) + energy + entropy +end + +function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELBO, q, q_stop, prob) + samples = rand(rng, q, obj.n_samples) + estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob) +end + +""" + estimate_objective([rng,] obj, q, prob; n_samples) + +Estimate the ELBO using the reparameterization gradient formulation. + +# Arguments +- `obj::RepGradELBO`: The ELBO objective. +- `q`: Variational approximation +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. + +# Keyword Arguments +- `n_samples::Int = obj.n_samples`: Number of samples to be used to estimate the objective. + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_objective( + rng::Random.AbstractRNG, + obj::RepGradELBO, + q, + prob; + n_samples::Int = obj.n_samples +) + samples = rand(rng, q, n_samples) + estimate_repgradelbo_maybe_stl_with_samples(obj, q, q, samples, prob) +end + +estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = + estimate_objective(Random.default_rng(), obj, q, prob; n_samples) + +function estimate_gradient!( + rng ::Random.AbstractRNG, + obj ::RepGradELBO, + adbackend::ADTypes.AbstractADType, + out ::DiffResults.MutableDiffResult, + prob, + λ, + restructure, + est_state, +) + q_stop = restructure(λ) + function f(λ′) + q = restructure(λ′) + elbo = estimate_repgradelbo_maybe_stl(rng, obj, q, q_stop, prob) + -elbo + end + value_and_gradient!(adbackend, f, λ, out) + + nelbo = DiffResults.value(out) + stat = (elbo=-nelbo,) + + out, nothing, stat +end diff --git a/src/utils.jl b/src/utils.jl index 8dd7c37b..76637fa3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -21,3 +21,6 @@ function maybe_init_objective( haskey(state_init, :objective) ? 
         state_init.objective : init(rng, objective, λ, restructure)
 end
+eachsample(samples::AbstractMatrix) = eachcol(samples)
+
+eachsample(samples::AbstractVector) = samples
diff --git a/test/Project.toml b/test/Project.toml
index 490782cb..7d0bf2d2 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,5 @@
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
-Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -23,7 +22,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 ADTypes = "0.2.1"
-Bijectors = "0.13.6"
 Distributions = "0.25.100"
 DistributionsAD = "0.6.45"
 Enzyme = "0.11.7"
diff --git a/test/inference/advi_distributionsad_bijectors.jl b/test/inference/advi_distributionsad_bijectors.jl
deleted file mode 100644
index 29602fe7..00000000
--- a/test/inference/advi_distributionsad_bijectors.jl
+++ /dev/null
@@ -1,81 +0,0 @@
-
-const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
-
-using Test
-
-@testset "inference_advi_distributionsad_bijectors" begin
-    @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
-            :NormalLogNormalMeanField => normallognormal_meanfield,
-        ),
-        (objname, objective) ∈ Dict(
-            :ADVIClosedFormEntropy  => ADVI(10),
-            :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
-        ),
-        (adbackname, adbackend) ∈ Dict(
-            :ForwarDiff => AutoForwardDiff(),
-            #:ReverseDiff => AutoReverseDiff(),
-            #:Zygote => AutoZygote(),
-            #:Enzyme => AutoEnzyme(),
-        )
-
-        seed = (0x38bef07cf9cc549d)
-        rng  = StableRNG(seed)
-
-        modelstats = modelconstr(rng, realtype)
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
-
-        b    = Bijectors.bijector(model)
-        b⁻¹  = inverse(b)
-        μ₀   = Zeros(realtype, n_dims)
-        L₀   = Diagonal(Ones(realtype, n_dims))
-
-        q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
-        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
-
-        @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-            q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
-                optimizer = Optimisers.Adam(realtype(η)),
-                show_progress = PROGRESS,
-                adbackend = adbackend,
-            )
-
-            μ  = mean(q.dist)
-            L  = sqrt(cov(q.dist))
-            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
-
-            @test Δλ ≤ Δλ₀/T^(1/4)
-            @test eltype(μ) == eltype(μ_true)
-            @test eltype(L) == eltype(L_true)
-        end
-
-        @testset "determinism" begin
-            rng = StableRNG(seed)
-            q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
-                optimizer = Optimisers.Adam(realtype(η)),
-                show_progress = PROGRESS,
-                adbackend = adbackend,
-            )
-            μ = mean(q.dist)
-            L = sqrt(cov(q.dist))
-
-            rng_repl = StableRNG(seed)
-            q, stats, _ = optimize(
-                rng_repl, model, objective, q₀_z, T;
-                optimizer = Optimisers.Adam(realtype(η)),
-                show_progress = PROGRESS,
-                adbackend = adbackend,
-            )
-            μ_repl = mean(q.dist)
-            L_repl = sqrt(cov(q.dist))
-            @test μ == μ_repl
-            @test L == L_repl
-        end
-    end
-end
diff --git a/test/inference/advi_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
similarity index 78%
rename from test/inference/advi_distributionsad.jl
rename to test/inference/repgradelbo_distributionsad.jl
index e82a9ec0..29cb2d83 100644
--- a/test/inference/advi_distributionsad.jl
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -3,15 +3,15 @@
 const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

 using Test

-@testset "inference_advi_distributionsad" begin
+@testset "inference RepGradELBO DistributionsAD" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
         realtype ∈ [Float64, Float32],
         (modelname, modelconstr) ∈ Dict(
             :Normal=> normal_meanfield,
         ),
         (objname, objective) ∈ Dict(
-            :ADVIClosedFormEntropy  => ADVI(10),
-            :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
+            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
         ),
         (adbackname, adbackend) ∈ Dict(
             :ForwarDiff => AutoForwardDiff(),
@@ -28,14 +28,14 @@ using Test

         T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
-        μ₀ = Zeros(realtype, n_dims)
-        L₀ = Diagonal(Ones(realtype, n_dims))
-        q₀_z = TuringDiagMvNormal(μ₀, diag(L₀))
+        μ0 = Zeros(realtype, n_dims)
+        L0 = Diagonal(Ones(realtype, n_dims))
+        q0 = TuringDiagMvNormal(μ0, diag(L0))

         @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
             q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
+                rng, model, objective, q0, T;
                 optimizer = Optimisers.Adam(realtype(η)),
                 show_progress = PROGRESS,
                 adbackend = adbackend,
@@ -53,7 +53,7 @@ using Test
         @testset "determinism" begin
             rng = StableRNG(seed)
             q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
+                rng, model, objective, q0, T;
                 optimizer = Optimisers.Adam(realtype(η)),
                 show_progress = PROGRESS,
                 adbackend = adbackend,
@@ -63,7 +63,7 @@ using Test

             rng_repl = StableRNG(seed)
             q, stats, _ = optimize(
-                rng_repl, model, objective, q₀_z, T;
+                rng_repl, model, objective, q0, T;
                 optimizer = Optimisers.Adam(realtype(η)),
                 show_progress = PROGRESS,
                 adbackend = adbackend,
diff --git a/test/interface/advi.jl b/test/interface/advi.jl
deleted file mode 100644
index 1df396e4..00000000
--- a/test/interface/advi.jl
+++ /dev/null
@@ -1,55 +0,0 @@
-
-using Test
-
-@testset "advi" begin
-    seed = (0x38bef07cf9cc549d)
-    rng  = StableRNG(seed)
-
-    @testset "with bijector" begin
-        modelstats = normallognormal_meanfield(rng, Float64)
-
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-        b⁻¹  = Bijectors.bijector(model) |> inverse
-        q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
-        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
-        obj  = ADVI(10)
-
-        rng      = StableRNG(seed)
-        elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-
-        @testset "determinism" begin
-            rng  = StableRNG(seed)
-            elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-            @test elbo == elbo_ref
-        end
-
-        @testset "default_rng" begin
-            elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
-            @test elbo ≈ elbo_ref rtol=0.1
-        end
-    end
-
-    @testset "without bijector" begin
-        modelstats = normal_meanfield(rng, Float64)
-
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-        q₀_z = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
-
-        obj = ADVI(10)
-        rng = StableRNG(seed)
-        elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-
-        @testset "determinism" begin
-            rng  = StableRNG(seed)
-            elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-            @test elbo == elbo_ref
-        end
-
-        @testset "default_rng" begin
-            elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
-            @test elbo ≈ elbo_ref rtol=0.1
-        end
-    end
-end
diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl
index 3459b4c3..6e69616b 100644
--- a/test/interface/optimize.jl
+++ b/test/interface/optimize.jl
@@ -1,27 +1,25 @@
 using Test

-@testset "optimize" begin
+@testset "interface optimize" begin
     seed = (0x38bef07cf9cc549d)
     rng  = StableRNG(seed)

     T = 1000
-    modelstats = normallognormal_meanfield(rng, Float64)
+    modelstats = normal_meanfield(rng, Float64)

     @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

     # Global Test Configurations
-    b⁻¹  = Bijectors.bijector(model) |> inverse
-    q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
-    q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
-    obj  = ADVI(10)
+    q0  = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+    obj = RepGradELBO(10)

     adbackend = AutoForwardDiff()
     optimizer = Optimisers.Adam(1e-2)
     rng = StableRNG(seed)
     q_ref, stats_ref, _ = optimize(
-        rng, model, obj, q₀_z, T;
+        rng, model, obj, q0, T;
         optimizer,
         show_progress = false,
         adbackend,
@@ -30,13 +28,13 @@ using Test

     @testset "default_rng" begin
         optimize(
-            model, obj, q₀_z, T;
+            model, obj, q0, T;
             optimizer,
             show_progress = false,
             adbackend,
         )

-        λ₀, re = Optimisers.destructure(q₀_z)
+        λ₀, re = Optimisers.destructure(q0)
         optimize(
             model, obj, re, λ₀, T;
             optimizer,
@@ -46,7 +44,7 @@ using Test
     end

     @testset "restructure" begin
-        λ₀, re = Optimisers.destructure(q₀_z)
+        λ₀, re = Optimisers.destructure(q0)

         rng = StableRNG(seed)
         λ, stats, _ = optimize(
@@ -67,7 +65,7 @@ using Test

         rng = StableRNG(seed)
         _, stats, _ = optimize(
-            rng, model, obj, q₀_z, T;
+            rng, model, obj, q0, T;
             show_progress = false,
             adbackend,
             callback
@@ -82,7 +80,7 @@ using Test
         T_last  = T - T_first

         q_first, _, state = optimize(
-            rng, model, obj, q₀_z, T_first;
+            rng, model, obj, q0, T_first;
             optimizer,
             show_progress = false,
             adbackend
diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl
new file mode 100644
index 00000000..61ff0111
--- /dev/null
+++ b/test/interface/repgradelbo.jl
@@ -0,0 +1,28 @@
+
+using Test
+
+@testset "interface RepGradELBO" begin
+    seed = (0x38bef07cf9cc549d)
+    rng  = StableRNG(seed)
+
+    modelstats = normal_meanfield(rng, Float64)
+
+    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+    q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+
+    obj = RepGradELBO(10)
+    rng = StableRNG(seed)
+    elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4)
+
+    @testset "determinism" begin
+        rng  = StableRNG(seed)
+        elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4)
+        @test elbo == elbo_ref
+    end
+
+    @testset "default_rng" begin
+        elbo = estimate_objective(obj, q0, model; n_samples=10^4)
+        @test elbo ≈ elbo_ref rtol=0.1
+    end
+end
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
deleted file mode 100644
index c2cb2b0e..00000000
--- a/test/models/normallognormal.jl
+++ /dev/null
@@ -1,65 +0,0 @@
-
-struct NormalLogNormal{MX,SX,MY,SY}
-    μ_x::MX
-    σ_x::SX
-    μ_y::MY
-    Σ_y::SY
-end
-
-function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
-end
-
-function LogDensityProblems.dimension(model::NormalLogNormal)
-    length(model.μ_y) + 1
-end
-
-function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
-end
-
-function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    Bijectors.Stacked(
-        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:1+length(μ_y)])
-end
-
-function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)
-    n_dims = 5
-
-    μ_x = randn(rng, realtype)
-    σ_x = ℯ
-    μ_y = randn(rng, realtype, n_dims)
-    L_y = tril(I + ones(realtype, n_dims, n_dims))/2
-    Σ_y = L_y*L_y' |> Hermitian
-
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0)))
-
-    Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
-    Σ[1,1]         = σ_x^2
-    Σ[2:end,2:end] = Σ_y
-    Σ = Σ |> Hermitian
-
-    μ = vcat(μ_x, μ_y)
-    L = cholesky(Σ).L
-
-    TestModel(model, μ, L, n_dims+1, false)
-end
-
-function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)
-    n_dims = 5
-
-    μ_x = randn(rng, realtype)
-    σ_x = ℯ
-    μ_y = randn(rng, realtype, n_dims)
-    σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
-
-    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
-
-    μ = vcat(μ_x, μ_y)
-    L = vcat(σ_x, σ_y) |> Diagonal
-
-    TestModel(model, μ, L, n_dims+1, true)
-end
diff --git a/test/runtests.jl b/test/runtests.jl
index 757a931d..a855541c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -14,7 +14,6 @@ using Functors
 using DistributionsAD
 @functor TuringDiagMvNormal

-using Bijectors
 using LogDensityProblems
 using Optimisers
 using ADTypes
@@ -30,14 +29,11 @@ struct TestModel{M,L,S}
     n_dims::Int
     is_meanfield::Bool
 end
-
-include("models/normallognormal.jl")
 include("models/normal.jl")

 # Tests
 include("interface/ad.jl")
 include("interface/optimize.jl")
-include("interface/advi.jl")
+include("interface/repgradelbo.jl")

-include("inference/advi_distributionsad.jl")
-include("inference/advi_distributionsad_bijectors.jl")
+include("inference/repgradelbo_distributionsad.jl")
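
Reviewer note (not part of the patch): after this refactor the objective is constructed as `RepGradELBO` instead of `ADVI`, while the `optimize` entry point is unchanged. A minimal usage sketch under stated assumptions — `model` is any target implementing the `LogDensityProblems` interface; the remaining names are taken from the test suite above:

```julia
using AdvancedVI, ADTypes, LogDensityProblems, Optimisers
using DistributionsAD: TuringDiagMvNormal

# `model` is assumed to be defined elsewhere and to implement
# the LogDensityProblems interface.
n_dims = LogDensityProblems.dimension(model)
q0     = TuringDiagMvNormal(zeros(n_dims), ones(n_dims))

obj = RepGradELBO(10)                                           # was: ADVI(10)
# obj = RepGradELBO(10, entropy = StickingTheLandingEntropy())  # STL variant

q, stats, _ = optimize(
    model, obj, q0, 10_000;
    optimizer     = Optimisers.Adam(1e-2),
    adbackend     = AutoForwardDiff(),
    show_progress = false,
)

# One-off Monte Carlo estimate of the ELBO at the fitted approximation.
elbo = estimate_objective(obj, q, model; n_samples = 10^4)
```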
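Reviewer note (not part of the patch): dropping the `Bijectors` dependency also removes the `Bijectors.TransformedDistribution` code path (the deleted `estimate_energy_with_samples_bijector` above), so targets with constrained support are no longer handled by the objective itself. One way for downstream users to recover that behavior is to fold the change of variables into the target's log-density, mirroring the deleted code. A sketch under stated assumptions; the `TransformedLogDensity` wrapper below is hypothetical, not an AdvancedVI API:

```julia
using Bijectors, LogDensityProblems

# Hypothetical user-side wrapper: evaluates
#     log π(b⁻¹(η)) + log |det J_{b⁻¹}(η)|,
# so that VI runs entirely in the unconstrained space of η.
struct TransformedLogDensity{P,B}
    prob::P   # original problem on the constrained space
    binv::B   # inverse bijector, mapping unconstrained -> constrained
end

function LogDensityProblems.logdensity(t::TransformedLogDensity, η)
    z, logdetjac = Bijectors.with_logabsdet_jacobian(t.binv, η)
    LogDensityProblems.logdensity(t.prob, z) + logdetjac
end

LogDensityProblems.dimension(t::TransformedLogDensity) =
    LogDensityProblems.dimension(t.prob)

# Usage: wrap the model once, then optimize against the wrapper.
# prob_unconstrained = TransformedLogDensity(model, inverse(Bijectors.bijector(model)))
```

With this wrapper, the fitted `q` approximates the posterior in unconstrained coordinates; draws can be pushed back through `t.binv` to obtain constrained-space samples.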